slope.core
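# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): slope.core provides the Tensor/Operator
# tracing machinery and the Backend base class; a concrete backend module is
# loaded at runtime via set_backend(). The backend name below and the
# top-level `slope` re-exports are assumptions, not defined in this file.
#
#   import slope
#   slope.core.set_backend("numpy_backend")        # hypothetical slope.backends.numpy_backend
#   x = slope.core.backend.tensor([1.0, 2.0, 3.0]) # wraps a TensorBuffer
#   y = (x * 2 + 1).sum()                          # ops dispatch through Tensor.__getattr__
#   print(y.numpy())
# ---------------------------------------------------------------------------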
1from pathlib import Path 2import os 3import json 4from typing import ( 5 Callable, 6 NamedTuple, 7 Dict, 8 Hashable, 9 List, 10 Any, 11 Iterable, 12 Sequence, 13 Iterator, 14 Type, 15 Tuple, 16 Optional, 17 Union, 18 Dict, 19 Set, 20 DefaultDict, 21 Final, 22) 23import weakref 24import types 25from contextlib import contextmanager, ContextDecorator 26import itertools 27import weakref 28import operator as operator_py 29import numpy as np 30import math 31import inspect 32from functools import partial, lru_cache 33import mmap 34import traceback 35import importlib 36import time 37import cProfile 38import pstats 39 40# ================================= 41# Utils 42# ================================= 43 44 45class Timing(ContextDecorator): 46 def __init__(self, prefix="", on_exit=None, enabled=True): 47 self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled 48 49 def __enter__(self): 50 self.st = time.perf_counter_ns() 51 52 def __exit__(self, *exc): 53 self.et = time.perf_counter_ns() - self.st 54 if self.enabled: 55 print(f"{self.prefix}{self.et*1e-6:6.2f} ms" + (self.on_exit(self.et) if self.on_exit else "")) 56 57 58def colored(st, color: Optional[str], background=False): 59 return ( 60 f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" 61 if color is not None 62 else st 63 ) # replace the termcolor library with one line # noqa: E501 64 65 66def _format_fcn(fcn): 67 return f"{fcn[0]}:{fcn[1]}:{fcn[2]}" 68 69 70class Profiling(ContextDecorator): 71 def __init__(self, enabled=True, sort="cumtime", frac=0.2, fn=None, ts=1): 72 self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3 / ts 73 74 def __enter__(self): 75 self.pr = cProfile.Profile() 76 if self.enabled: 77 self.pr.enable() 78 79 def __exit__(self, *exc): 80 if self.enabled: 81 self.pr.disable() 82 if self.fn: 83 self.pr.dump_stats(self.fn) 84 stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort) 85 for fcn in stats.fcn_list[0 : int(len(stats.fcn_list) * self.frac)]: # type: ignore[attr-defined] 86 (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined] 87 scallers = sorted(callers.items(), key=lambda x: -x[1][2]) 88 print( 89 f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms", 90 colored(_format_fcn(fcn), "yellow") + " " * (50 - len(_format_fcn(fcn))), 91 colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else "", 92 ) 93 94 95def dblog(*msg, enable=True): 96 if enable: 97 print(*msg) 98 99 100def unzip2(pairs) -> Tuple[List[Any], List[Any]]: 101 lst1, lst2 = [], [] 102 for i1, i2 in pairs: 103 lst1 += [i1] 104 lst2 += [i2] 105 return lst1, lst2 106 107 108def list_map(f: Callable, *xs: Iterable) -> List[Any]: 109 return list(map(f, *xs)) 110 111 112def list_zip(*args: List[Any]) -> List[Any]: 113 fst, *rest = args = list_map(list, args) 114 n = len(fst) 115 for arg in rest: 116 assert len(arg) == n 117 return list(zip(*args)) 118 119 120def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]: 121 assert not len(lst) % 2 122 return split_list(lst, len(lst) // 2) 123 124 125def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]: 126 l1, l2 = iter(l1), iter(l2) 127 out = [next(l2) if b else next(l1) for b in which] 128 assert next(l1, None) is next(l2, None) is None 129 return out 130 131 
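# Example (sketch): the list helpers above mirror the small jax-style utilities
# and have no backend dependency. merge_lists is the inverse of partition_list
# (defined just below) for the same boolean mask.
#
#   letters, nums = unzip2([("a", 1), ("b", 2)])    # ["a", "b"], [1, 2]
#   with Timing("zip: "):
#       rows = list_zip(letters, nums)              # [("a", 1), ("b", 2)]
#   merged = merge_lists([False, True, False], ["x", "z"], ["Y"])  # ["x", "Y", "z"]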
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
    assert 0 <= n <= len(lst)
    return lst[:n], lst[n:]


def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
    assert len(bs) == len(l)
    lists = lst1, lst2 = [], []  # False mask goes to lst1, True mask goes to lst2
    for b, x in zip(bs, l):
        lists[b].append(x)
    return lst1, lst2


def lru_cache_verbose(
    maxsize: int = 100,
    typed: bool = False,
    tb_start: int = -12,
    tb_end: int = -7,
):
    """lru_cache decorator that logs cache stats and a slice of the calling traceback when backend.LOG_LRU is set."""

    def decorator(fn: Callable):
        @lru_cache(maxsize=maxsize, typed=typed)
        def wrapper(*args, **kwargs) -> Any:
            return fn(*args, **kwargs)

        def decorated_function(*args, **kwargs) -> Any:
            result = wrapper(*args, **kwargs)
            cache_info = wrapper.cache_info()

            dblog(
                f"{fn.__name__}.{cache_info} {args.__hash__()}",
                enable=backend.LOG_LRU,
            )
            tb = "".join(traceback.format_list(traceback.extract_stack())[tb_start:tb_end]).replace("\n ", ":\t") + "-" * 20 + "\n"
            dblog(f"{tb}", enable=backend.LOG_LRU)

            return result

        decorated_function.cache_info = wrapper.cache_info
        decorated_function.fn = fn
        return decorated_function

    return decorator


def cuda_is_available():
    try:
        import subprocess
        import platform

        cmd = f"nvidia-smi{'.exe' if platform.system() == 'Windows' else ''}"
        result = subprocess.run([cmd], stdout=subprocess.PIPE)
        output = result.stdout.decode("utf-8")
        return "NVIDIA-SMI" in output
    except FileNotFoundError:
        return False


class Hashed:
    val: Any

    def __init__(self, val):
        self.val = val

    def __hash__(self) -> int:
        return hash((self.val,))

    def __eq__(self, other):
        if isinstance(other, Hashed):
            if isinstance(self.val, Tensor) and isinstance(other.val, Tensor):
                # Tensor.__eq__ is overloaded to build an elementwise Tensor.equal op,
                # so compare Tensors by identity here.
                return id(self.val) == id(other.val)
            return self.val == other.val
        return False

    def __repr__(self):
        return f"Hashed: {repr(self.val)}"


# =================================
# Tensors
# =================================


class DType(NamedTuple):
    priority: int
    itemsize: int
    name: str
    mlir: str
    numpy: type

    @property
    def format_code(self):
        return f"slope.{self.name}"

    def __repr__(self):
        return f"<DType: {self.name}>"


class dtypes:
    float32: Final[DType] = DType(4, 4, "float32", "f32", np.float32)
    uint8: Final[DType] = DType(0, 1, "uint8", "u8", np.uint8)
    int8: Final[DType] = DType(0, 1, "int8", "i8", np.int8)
    bool: Final[DType] = DType(0, 1, "bool", "i1", bool)
    int32: Final[DType] = DType(1, 4, "int32", "i32", np.int32)
    int64: Final[DType] = DType(2, 8, "int64", "i64", np.int64)
    uint64: Final[DType] = DType(2, 8, "uint64", "ui64", np.uint64)
    float16: Final[DType] = DType(0, 2, "float16", "f16", np.float16)
    half = float16
    # bfloat16: Final[DType] = DType(0, 2, "bfloat16", "bf16", np.float16)

    all_dtypes = (bool, float16, float32, int8, int32, int64, uint8, uint64)
    name_dtype_map = {k.name: k for k in all_dtypes}
    name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
    mlir_dtype_map = {k.mlir: k for k in all_dtypes}
    mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}

    @classmethod
    def is_int(cls, dtype):
        return dtype in (cls.uint8, cls.int8, cls.int32, cls.uint64, cls.int64)

    @classmethod
    def is_float(cls, dtype):
        # bfloat16 is not registered above, so only the enabled float dtypes are checked.
        return dtype in (cls.float16, cls.float32)


class Device(NamedTuple):
    name: str
    idx: int

    @property
    def format_code(self):
        return f"'{self.name}:{self.idx}'"

    def __repr__(self):
        return f"<Device: {self.format_code}>"


class devices:
    cpu: Final[Device] = Device("cpu", 0)
    metal: Final[Device] = Device("metal", 0)
    cuda0: Final[Device] = Device("cuda", 0)
    # TODO: programmatically define these class attrs to support other setups
    cuda = cuda0
    all_devices = (cpu, metal, cuda0)
    name_idx_device_map = {f"{k.name}:{k.idx}": k for k in all_devices}
    name_idx_device_map_inv = {v: k for k, v in name_idx_device_map.items()}


class TensorBuffer:
    def __init__(self, val):
        self.val = val


class Tensor:
    def __init__(self, val: TensorBuffer):
        assert isinstance(val, TensorBuffer)
        self.buf = val

    @property
    def symval(self):
        return SymbolicTensor.like(self)

    @property
    def default_dtype(self):
        return backend.default_dtype

    def is_int(self) -> bool:
        return self.dtype in (
            dtypes.int8,
            dtypes.uint8,
            dtypes.uint64,
            dtypes.int32,
            dtypes.int64,
        )

    def is_float(self) -> bool:
        return self.dtype in (dtypes.float16, dtypes.float32)

    def is_unsigned(self) -> bool:
        return self.dtype is dtypes.uint8

    def to_bool(self):
        return self.cast(dtypes.bool)

    def short(self):
        return self.cast(dtypes.int8)

    def int(self):
        return self.cast(dtypes.int32)

    def long(self):
        return self.cast(dtypes.int64)

    def half(self):
        return self.cast(dtypes.float16)

    def float(self):
        return self.cast(dtypes.float32)

    def __getattr__(self, attr):
        # Unknown attributes resolve to backend operators/procedures, partially applied to self.
        if attr in vars(backend.operator_set).keys():
            op = getattr(backend.operator_set, attr)
            return partial(op, self)
        elif attr in vars(backend.procedure_set).keys():
            procedure = getattr(backend.procedure_set, attr)
            assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}"
            return partial(procedure, self)
        else:
            return self.__getattribute__(attr)

    def __getitem__(self, idx):
        return self.getitem(idx)

    def __setitem__(self, idx, item):
        raise NotImplementedError

    def str_short(self):
        return f"<Tensor: shape={self.shape}, dtype={self.dtype}>"

    __neg__ = lambda self: self.neg()
    __add__ = lambda self, other: self.add(other)
    __radd__ = lambda self, other: self.add(other)
    __sub__ = lambda self, other: self.sub(other)
    __rsub__ = lambda self, other: self.sub.func(other, self)
    __mul__ = lambda self, other: self.mul(other)
    __rmul__ = lambda self, other: self.mul(other)
    __div__ = lambda self, other: self.div(other)
    __rdiv__ = lambda self, other: self.div.func(other, self)
    __truediv__ = __div__
    __truerdiv__ = __rdiv__
    __pow__ = lambda self, other: self.pow(other)
    __rpow__ = lambda self, other: self.pow.func(other, self)
    __matmul__ = lambda self, other: self.matmul(other)
    __rmatmul__ = lambda self, other: self.matmul.func(other, self)
    __invert__ = lambda self: self.invert()
    __eq__ = lambda self, other: self.equal(other)
    __ne__ = lambda self, other: self.not_equal(other)
    __ge__ = lambda self, other:
self.greater_equal(other) 375 __le__ = lambda self, other: self.less_equal(other) 376 __gt__ = lambda self, other: self.greater(other) 377 __lt__ = lambda self, other: self.less(other) 378 379 def __hash__(self): 380 return id(self.val) 381 382 val = property(lambda self: self.buf.val) 383 384 def size(self, i): 385 return self.shape[i] 386 387 @property 388 def dtype(self): 389 return backend.dtype_of(self) 390 391 @property 392 def device(self): 393 return backend.device_of(self) 394 395 def numpy(self, memmap=False): 396 return backend.numpy_of(self, memmap) 397 398 @property 399 def shape(self): 400 return backend.shape_of(self) 401 402 @property 403 def ndim(self): 404 return len(self.shape) 405 406 def numel(self): 407 return math.prod(self.shape) 408 409 def element_size(self): 410 return self.dtype.itemsize 411 412 def nbytes(self): 413 return self.numel() * self.element_size() 414 415 def __repr__(self): 416 return f"<Tensor: val=\n{self.numpy()}\nshape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}>" 417 418 419class SymbolicTensor(Tensor): 420 def __init__(self, shape, dtype, device): 421 assert isinstance(dtype, DType) 422 self._shape = tuple(int(i) for i in shape) 423 self._dtype = dtype 424 self._device = device 425 426 @property 427 def symval(self): 428 return self 429 430 @property 431 def val(self): 432 raise RuntimeError(f"SymbolicTensor actually has no val, from {trace_stack[-1]=}, ") 433 434 @property 435 def shape(self): 436 return self._shape 437 438 @property 439 def dtype(self): 440 return self._dtype 441 442 @property 443 def device(self): 444 return self._device 445 446 def like(self, **overrides): 447 shape = overrides.get("shape", self.shape) 448 dtype = overrides.get("dtype", self.dtype) 449 device = overrides.get("device", self.device) 450 return SymbolicTensor(shape, dtype, device) 451 452 def str_short(self): 453 return f'{str(self.dtype)}[{",".join(str(d) for d in self.shape)}]' 454 455 def __hash__(self): 456 return hash((self.shape, self.dtype)) 457 458 def __eq__(self, other): 459 if type(self) != type(other): 460 return False 461 return (self.shape == other.shape) and (self.dtype == other.dtype) 462 463 def __repr__(self): 464 return f"<SymbolicTensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device}>" 465 466 467# ================================= 468# Operator 469# ================================= 470 471 472class Operator: 473 def __init__(self, name, variadic_inputs=False, nary_outputs=False): 474 self.name = name 475 self.variadic_inputs = variadic_inputs 476 self.nary_outputs = nary_outputs 477 if self.variadic_inputs: 478 self.reorg_args = self.reorg_args_nary 479 480 def __hash__(self): 481 return hash(self.name) 482 483 def __eq__(self, other): 484 if not isinstance(other, Operator): 485 return False 486 return self.name == other.name 487 488 def args_fixer(self, *args, **params): 489 return args, params 490 491 def __call__(self, *args, **params): 492 args, params = self.reorg_args(args, params) 493 args, params = self.args_fixer(*args, **params) 494 ret = bind(self, *args, **params) 495 if not self.nary_outputs: 496 ret = ret[0] 497 return ret 498 499 def __repr__(self) -> str: 500 return f"<{self.name}>" 501 502 def typecheck(self, *args, **params): 503 raise NotImplementedError 504 505 def jvp(self, *args, **params): 506 raise NotImplementedError 507 508 def T(self, *args, **params): 509 raise NotImplementedError 510 511 def vmap(self, *args, **params): 512 raise NotImplementedError 513 514 def 
reorg_args(self, args, params): 515 sig = inspect.signature(self.typecheck) 516 args_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and k != "self"] 517 params_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.KEYWORD_ONLY and k != "self"] 518 519 if args: 520 if len(args) > len(args_strs): 521 args, rest = args[: len(args_strs)], args[len(args_strs) :] 522 if params_strs: 523 new_params = {k: rest_arg for k, rest_arg in zip(params_strs, rest) if k not in params} 524 params = {**new_params, **params} 525 else: 526 args = tuple([params[k] if k in params else arg for k, arg in zip(args_strs, args)]) 527 assert len(args) == len(args_strs) 528 return args, params 529 530 def reorg_args_nary(self, args, params): 531 return args, params 532 533 def partial_run(self, trace, tracers, **params): 534 tracers_in = [trace.instantiate_const(t) for t in tracers] 535 symvals_in = [t.symval for t in tracers_in] 536 symvals_out = self.typecheck(*symvals_in, **params) 537 tracers_out = [PartialRunTraceTensor(trace, make_unknown_pval(symval), None) for symval in symvals_out] 538 instruction = InstructionDraft( 539 self, 540 tracers_in, 541 params, 542 symvals_out, 543 list_map(weakref.ref, tracers_out), 544 ) 545 for t in tracers_out: 546 t.draft = instruction 547 return tracers_out 548 549 def partial_run_instruction(self, unks_in, instruction): 550 if any(unks_in): 551 instruction1 = None 552 instruction2 = Instruction( 553 instruction.op, 554 instruction.inputs, 555 instruction.params, 556 instruction.out_binders, 557 ) 558 unks_out = [True for i in instruction.out_binders] 559 res = [v for unk, v in zip(unks_in, instruction.inputs) if ((not unk) and type(v) is Var)] 560 else: 561 instruction1 = instruction 562 instruction2 = None 563 unks_out = [False for i in instruction.out_binders] 564 res = None 565 566 return instruction1, instruction2, unks_out, res 567 568 569class MetaOperator(Operator): 570 def meta_impl(self, *args, **kwargs): 571 raise NotImplementedError 572 573 574class UnaryOperator(Operator): 575 def vmap(self, x, *, dim_size, vals_in, dims_in, **params): 576 (x,), (x_bdim,) = vals_in, dims_in 577 return [self(x, **params)], [x_bdim] 578 579 def typecheck(self, x, **params): 580 return [SymbolicTensor.like(x)] 581 582 def jvp(self, primals, tangents, **params): 583 (x,), (x_dot,) = primals, tangents 584 return [self(x, **params)], [self(x_dot, **params)] 585 586 587class BinaryOperator(Operator): 588 boolean_output = False 589 590 def args_fixer(self, x, w, **params): 591 if isinstance(x, UndefinedPrimal) or type(w) is UndefinedPrimal: 592 assert x.shape == w.shape 593 return (x, w), params 594 595 if type(x) in TraceTensor.PYTHON_TYPES: 596 x = backend.full(shape=(), fill_value=x, dtype=w.dtype) 597 elif type(w) in TraceTensor.PYTHON_TYPES: 598 w = backend.full(shape=(), fill_value=w, dtype=x.dtype) 599 600 shape_delta = x.ndim - w.ndim 601 if shape_delta > 0: 602 w = w.reshape((1,) * shape_delta + w.shape) 603 elif shape_delta < 0: 604 x = x.reshape((1,) * -shape_delta + x.shape) 605 606 shape_ret = tuple([max(x, w) for x, w in zip(x.shape, w.shape)]) 607 if x.shape != shape_ret: 608 x = x.expand(shape_ret) 609 if w.shape != shape_ret: 610 w = w.expand(shape_ret) 611 612 if type(x) is Tensor and isinstance(w, TraceTensor): 613 x = w._trace.pure(x) 614 elif type(w) is Tensor and isinstance(x, TraceTensor): 615 w = x._trace.pure(w) 616 # TODO: https://jax.readthedocs.io/en/latest/type_promotion.html 617 if x.dtype 
!= w.dtype: 618 # {int, bool} -> float 619 if dtypes.is_float(x.dtype) ^ dtypes.is_float(w.dtype): 620 if dtypes.is_float(w.dtype): 621 x = x.cast(w.dtype) 622 elif dtypes.is_float(x.dtype): 623 w = w.cast(x.dtype) 624 # bool -> int 625 elif dtypes.is_int(x.dtype) ^ dtypes.is_int(w.dtype): 626 if dtypes.is_int(w.dtype): 627 x = x.cast(w.dtype) 628 elif dtypes.is_int(x.dtype): 629 w = w.cast(x.dtype) 630 else: # TODO: fine-grained type promotions 631 raise NotImplementedError("No other type promotion rules") 632 633 return (x, w), params 634 635 def vmap(self, dim_size, vals_in, dims_in, **params): 636 (x, w), (x_bdim, w_bdim) = vals_in, dims_in 637 if x_bdim != w_bdim: 638 if x_bdim is None: 639 x = VMapTrace.move_vmap_dim(x, dim_size, x_bdim, w_bdim) 640 x_bdim = w_bdim 641 else: 642 w = VMapTrace.move_vmap_dim(w, dim_size, w_bdim, x_bdim) 643 return [self(x, w, **params)], [x_bdim] 644 645 def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]: 646 if not isinstance(x, (Tensor, SymbolicTensor)) or not isinstance(y, (Tensor, SymbolicTensor)): 647 raise TypeError 648 symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype) 649 symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype) 650 if x.dtype != y.dtype: 651 raise TypeError(f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})") 652 if symx == symy: 653 return [symx] 654 shape_delta = len(symx.shape) - len(symy.shape) 655 if shape_delta > 0: 656 symy = symy.like(shape=(1,) * shape_delta + symy.shape) 657 elif shape_delta < 0: 658 symx = symx.like(shape=(1,) * -shape_delta + symx.shape) 659 if symx == symy: 660 return [symx] 661 else: 662 shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)]) 663 if symx.shape != shape_ret: 664 symx = symx.like(shape=shape_ret) 665 if symy.shape != shape_ret: 666 symy = symx.like(shape=shape_ret) 667 if symx != symy: 668 raise TypeError(f"symx ({symx}) != symy ({symy})") 669 return [symx] 670 671 def jvp(self, primals, tangents, **params): 672 (x, w), (x_dot, w_dot) = primals, tangents 673 return [self(x, w, **params)], [self(x_dot, w_dot, **params)] 674 675 def T(self, cotangents, x, w): 676 (gL_y,) = cotangents 677 if self.boolean_output: 678 gL_y = gL_y.cast(x.dtype) 679 if isinstance(x, UndefinedPrimal): 680 return [gL_y, NullCotangent] 681 elif isinstance(w, UndefinedPrimal): 682 return [NullCotangent, gL_y] 683 else: 684 raise ValueError 685 686 687class ReduceOperator(Operator): 688 def args_fixer(self, x, *, dim=None, keepdim=False): 689 if dim is None: 690 dim = tuple(range(x.ndim)) 691 elif isinstance(dim, int): 692 dim = (dim,) 693 dim = tuple(a if a >= 0 else a + len(x.shape) for a in dim) 694 return (x,), dict(dim=dim, keepdim=keepdim) 695 696 def vmap(self, dim_size, vals_in, dims_in, *, dim, keepdim): 697 (x,), (x_bdim,) = vals_in, dims_in 698 dim = tuple(a + (x_bdim <= a) for a in dim) 699 out_bdim = x_bdim - sum(a < x_bdim for a in dim) 700 return [self(x, dim=dim, keepdim=keepdim)], [out_bdim] 701 702 def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]: 703 dim = list(set([a + len(x.shape) if a < 0 else a for a in dim])) 704 if keepdim: 705 new_shape = [d if i not in dim else 1 for i, d in enumerate(x.shape)] 706 else: 707 new_shape = [d for i, d in enumerate(x.shape) if i not in dim] 708 return [SymbolicTensor.like(x, shape=tuple(new_shape))] 709 710 711class InitOperator(Operator): 712 def vmap(self, dim_size, vals_in, dims_in, **params): 713 
        (x_bdim,) = dims_in
        y = self(**params)
        y = y.unsqueeze(x_bdim)
        return [y], [x_bdim]

    def jvp(self, primals, tangents, **params):
        y = self(**params)
        y_dot = NullCotangent(y.symval)
        return [y], [y_dot]

    def T(self, cotangents, **params):
        return [NullCotangent(cotangents[0])]


class ShapeOperator(Operator):
    pass


class GeneralReduceOperator(Operator):
    pass


class OperatorSet:
    def __init__(self):
        self.register("jit_op")(JitOp)

    def register(self, name, variadic_inputs=False, nary_outputs=False, aliases=()):
        def wrap(op_cls):
            assert name not in vars(self)
            op = op_cls(name, variadic_inputs, nary_outputs)
            setattr(self, name, op)
            for a in aliases:
                setattr(self, a, op)
            return op_cls

        return wrap


class ProcedureSet:
    def register(self, aliases=()):
        def wrap(f):
            assert f.__name__ not in vars(self)
            setattr(self, f.__name__, f)
            for a in aliases:
                setattr(self, a, f)
            return f

        return wrap


class CodegenOutput(NamedTuple):
    code_lines: List[str]
    fn_defs: Dict[str, List[str]]
    in_binders: List["ProgramEnvVar"]
    outs: List["ProgramEnvVar"]


class Backend:
    LOG_LRU = int(os.environ.get("LOG_LRU", 0))
    LOG_JIT = int(os.environ.get("LOG_JIT", 0))
    LOG_TREE = int(os.environ.get("LOG_TREE", 0))
    LOG_BACKEND = int(os.environ.get("LOG_BACKEND", 0))
    LOG_PROGRAM = int(os.environ.get("LOG_PROGRAM", 0))
    LOG_INIT = int(os.environ.get("LOG_INIT", 1))
    device_var = os.environ.get("DEFAULT_DEVICE", "cpu:0")
    if device_var[-2] != ":":
        device_var += ":0"
    DEFAULT_DEVICE = devices.name_idx_device_map[device_var]
    DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")]
    dtype_for_indices: DType = None  # need to override

    def __init__(
        self,
        operator_set: OperatorSet,
        procedure_set: ProcedureSet,
    ):
        self.operator_set = operator_set
        self.procedure_set = procedure_set
        self.node_types = dict()
        self.impls = dict()
        self.register_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs), "tuple")
        self.register_node(list, lambda l: (None, l), lambda _, xs: list(xs), "list")
        self.register_node(
            dict,
            lambda d: list_map(tuple, unzip2(sorted(d.items()))),
            lambda keys, vals: dict(list_zip(keys, vals)),
            "dict",
        )
        self.register_node(
            UndefinedPrimal,
            lambda u: (u.symval, ()),
            lambda symval, _: UndefinedPrimal(symval),
            "UndefinedPrimal",
        )

    def set_impl(self, op: Union[types.LambdaType, types.FunctionType]):
        def set_impl_(fn):
            self.impls[op] = types.MethodType(fn, self)

        return set_impl_

    def register_node(self, ty: Type, to_iter: Callable, from_iter: Callable, name=None) -> None:
        if name is None:
            name = str(ty)
        self.node_types[ty] = NodeType(name, to_iter, from_iter)

    def __getattr__(self, attr):
        try:
            dblog(
                f"Looking {self}.{attr} in operator_set",
                enable=backend.LOG_BACKEND,
            )
            return getattr(self.operator_set, attr)
        except AttributeError:
            pass
        try:
            dblog(
                f"Looking {self}.{attr} in procedure_set",
                enable=backend.LOG_BACKEND,
            )
            return getattr(self.procedure_set, attr)
        except AttributeError:
            pass
        dblog(
            f"Fallback to default {self} getattribute",
            enable=backend.LOG_BACKEND,
        )
        return super().__getattribute__(attr)

    def tensor(
        self,
        val: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
        dtype: Optional[Any] = None,
        device=None,
    ):
        if isinstance(val, TensorBuffer):
            return Tensor(val)
        elif isinstance(val, Tensor):
            return val
        if type(val) is bytes:
            val = np.frombuffer(val, dtype=dtype)
        return self.from_numpy(val, dtype, device)

    def symbolic_tensor(
        self,
        shape: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
        dtype: Optional[Any] = None,
        device=None,
    ):
        dtype = dtype or self.DEFAULT_DTYPE
        device = device or self.DEFAULT_DEVICE
        return SymbolicTensor(shape, dtype, device)

    def seed(self, seed):
        raise NotImplementedError

    @property
    def default_dtype_value(self):
        return self.dtype_map[backend.DEFAULT_DTYPE]

    def set_method(self, method):
        setattr(self, method.__name__, types.MethodType(method, self))

    def from_numpy(self, val, dtype, device):
        raise NotImplementedError

    def numpy_of(self, tensor):
        raise NotImplementedError

    def device_of(self, tensor):
        raise NotImplementedError

    def shape_of(self, tensor):
        raise NotImplementedError

    def dtype_of(self, tensor):
        raise NotImplementedError

    @lru_cache_verbose()
    def jit_program(
        self,
        hashed_program: Hashed,
        hashed_consts: Tuple[Hashed, ...],
    ):
        program: Program = hashed_program.val
        typecheck_program(program)
        consts = [x.val for x in hashed_consts]
        in_symvals = [v.symval for v in program.in_binders[len(consts) :]]
        codegen_output: CodegenOutput = self.codegen(program, consts + in_symvals, fn_name="main")
        fn, code = self.compile(codegen_output)
        jit_output = JitOutput(program, codegen_output, fn, code, consts)
        return jit_output

    def codegen(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str):
        "Returns compiler IR from the Program"
        raise NotImplementedError

    def compile(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str):
        "Compiles compiler IR to a Python callable function"
        raise NotImplementedError

    def export(self, jit_output, *args, **params):
        raise NotImplementedError

    def load(self, path, single_key="_tensor"):
        # File layout (written by save below): 8-byte int64 metadata length,
        # then the JSON metadata, then the raw tensor bytes.
        with open(path, mode="rb") as f:
            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as m:
                json_len = int(np.frombuffer(m[:8], dtype=np.int64)[0])
                start = 8 + json_len
                metadata = json.loads(m[8:start])
                ret = {}
                for k, v in metadata.items():
                    if k != "__metadata__":
                        dtype = dtypes.mlir_dtype_map[v["dtype"]]
                        data_start = start + v["data_offsets"][0]
                        data_end = start + v["data_offsets"][1]
                        t_np = np.frombuffer(m[data_start:data_end], dtype=dtype.numpy)
                        t = backend.tensor(t_np, dtype=dtype)
                        t = t.reshape(tuple(v["shape"]))
                        ret[k] = t
                if len(ret) == 1 and single_key in ret.keys():
                    return ret[single_key]
                return ret

    def save(self, tensors: Dict[str, Tensor], path: str, single_key="_tensor"):
        if isinstance(tensors, Tensor):
            tensors = {single_key: tensors}
        else:
            assert all((isinstance(k, str) and isinstance(v, Tensor)) for k, v in tensors.items())

        metadata, offset = {}, 0
        for k, v in tensors.items():
            metadata[k] = {
                "dtype": v.dtype.mlir,
                "shape": list(v.shape),
                "data_offsets": [offset, offset + v.nbytes()],
            }
            offset += v.nbytes()
        j = json.dumps(metadata, separators=(",", ":"))
        Path(path).unlink(missing_ok=True)
        jbytes = j.encode("utf-8")
        start = 8 + len(jbytes)
        with
open(path, mode="wb") as f: # make empty file, fill with enough space 956 f.write(b"\x00" * (start + offset)) 957 with open(path, mode="r+b") as f: 958 with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_WRITE) as m: 959 m[0:8] = np.int64(len(j)).tobytes() 960 m[8:start] = jbytes 961 for t, tm in zip(tensors.values(), metadata.values()): 962 data_start, data_end = tm["data_offsets"] 963 m[start + data_start : start + data_end] = t.numpy().tobytes() 964 965 966# ================================= 967# Program 968# ================================= 969 970 971class Var: 972 def __init__(self, symval): 973 self.symval = symval 974 self.val = None 975 976 977class Lit: 978 def __init__(self, val): 979 self.symval = SymbolicTensor.like(get_symval(val)) 980 self.val = val 981 982 983Atom = Union[Var, Lit] 984 985 986class Instruction(NamedTuple): 987 op: Operator 988 inputs: List[Atom] 989 params: Dict[str, Any] 990 out_binders: List[Atom] 991 992 993class ProgramEnvVar(NamedTuple): 994 name: str 995 symval: SymbolicTensor 996 is_const: bool = False 997 998 @property 999 def shape(self): 1000 return self.symval.shape 1001 1002 @property 1003 def dtype(self): 1004 return self.symval.dtype 1005 1006 @property 1007 def device(self): 1008 return self.symval.device 1009 1010 @property 1011 def ndim(self): 1012 return self.symval.ndim 1013 1014 def numpy(self): 1015 return self.symval.numpy() 1016 1017 def __repr__(self): 1018 return f"<ProgramEnvVar: name={self.name}, symval={self.symval}>" 1019 1020 str_short = __repr__ 1021 1022 1023class Program: 1024 def __init__( 1025 self, 1026 in_binders: Any, 1027 instructions: Tuple[Instruction], 1028 outs: Any, 1029 num_consts: int = 0, 1030 static_args: Any = (), 1031 name: str = "my_program", 1032 indent_amount=4, 1033 ): 1034 self.in_binders: Any = in_binders 1035 self.outs: Any = outs 1036 self.instructions = self.prune_instructions(instructions, outs) 1037 self.num_consts: int = num_consts 1038 self.static_args = static_args 1039 self.name: str = name 1040 self.indent_amount: int = indent_amount 1041 1042 self.env: Dict[ProgramEnvVar, Any] = dict() 1043 for inb in self.in_binders: 1044 prefix = "x" if type(inb.symval) is SymbolicTensor else "c" 1045 idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()]) 1046 self.env[inb] = ProgramEnvVar(f"{prefix}{idx}", inb.symval, True if prefix == "c" else False) 1047 for instruction in self.instructions: 1048 if len(instruction.out_binders) == 0: 1049 continue 1050 for outb in instruction.out_binders: 1051 prefix = "y" if outb in self.outs else "z" 1052 idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()]) 1053 self.env[outb] = ProgramEnvVar(f"{prefix}{idx}", outb.symval) 1054 self.curr_repr = repr(self) 1055 1056 def pprint_shape(self, symval, scalar_as_empty_array=False): 1057 xdtype = symval.dtype.mlir 1058 if len(symval.shape) > 0: 1059 xshape = f"{', '.join((repr(i) for i in symval.shape))}" 1060 return f"[{xshape}, {xdtype}]" 1061 else: 1062 return f"[{xdtype}]" 1063 1064 def pprint_sig(self, in_symvals, out_symvals, unpack_unary_output=False): 1065 in_code = ", ".join(self.pprint_shape(t) for t in in_symvals) 1066 in_code = f"({in_code})" if len(in_symvals) > 1 else in_code 1067 out_code = ", ".join(self.pprint_shape(t) for t in out_symvals) 1068 out_code = f"({out_code})" if len(out_symvals) > 1 or unpack_unary_output else out_code 1069 typing_code = f"{in_code} -> {out_code}" 1070 return typing_code 1071 1072 def __repr__(self): 1073 fn_defs = 
self.instructions_as_code(self, dict()) 1074 return "\n".join(line for code_lines in fn_defs.values() for line in code_lines) 1075 1076 def save(self, *args, dir_path="/tmp/slope_program", dry_run=False): 1077 os.makedirs(dir_path, exist_ok=True) 1078 head_code_lines = [f"import slope # backend={backend.__class__.__name__}"] 1079 fn_defs = self.instructions_as_code(self, dict()) 1080 in_binders_vars = [self.env[i] for i in self.in_binders] 1081 for i in range(len(self.in_binders)): 1082 ibv = in_binders_vars[i] 1083 if ibv.is_const: 1084 const_filename = f"{ibv.name}.safetensors" 1085 const_path = os.path.join(dir_path, f"{const_filename}") 1086 if not dry_run: 1087 backend.save(args[i], const_path) 1088 dblog( 1089 f"Saved {ibv.name} at {const_path}", 1090 enable=backend.LOG_BACKEND, 1091 ) 1092 head_code_lines += [f"""{ibv.name} = slope.load("./{const_filename}")"""] 1093 head_code_lines += [""] 1094 code = "\n".join(head_code_lines + [line for code_lines in fn_defs.values() for line in code_lines]) 1095 dblog( 1096 f"Contents of {self.name}:\n```\n{code}\n```", 1097 enable=backend.LOG_BACKEND, 1098 ) 1099 program_path = os.path.join(dir_path, "main.py") 1100 if not dry_run: 1101 with open(program_path, "w") as f: 1102 f.write(code) 1103 dblog( 1104 f"Saved program {self.name} at {program_path}", 1105 enable=backend.LOG_BACKEND, 1106 ) 1107 ls_contents = "\n\t".join(os.listdir(dir_path)) 1108 dblog( 1109 f"Contents of {dir_path}:\n\t{ls_contents}", 1110 enable=backend.LOG_BACKEND, 1111 ) 1112 1113 def __hash__(self): 1114 return hash(self.curr_repr) 1115 1116 def __eq__(self, other): 1117 return self is other 1118 1119 @classmethod 1120 def instructions_as_code(cls, program, fn_defs): 1121 def indent(code, indent_amount): 1122 spaces = " " * (len(code) - len(code.lstrip())) 1123 spaces += " " * indent_amount 1124 return "\n".join([spaces + line for line in code.strip().split("\n")]) 1125 1126 in_binders_vars = [program.env[i] for i in program.in_binders] 1127 body_code_lines = [] 1128 for instruction in program.instructions: 1129 if len(instruction.out_binders) == 0: 1130 continue 1131 params = instruction.params.copy() 1132 for param_name, param in params.items(): 1133 if isinstance(param, Program): 1134 sub_program = param 1135 fn_defs = cls.instructions_as_code(sub_program, fn_defs) 1136 program_in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs) 1137 params[param_name] = f"slope.make_program({sub_program.name}, {program_in_vals})[0]" 1138 if isinstance(param, DType): 1139 params[param_name] = f"slope.{param.name}" 1140 param_vals = ", ".join(f"{param_name}={param}" for param_name, param in params.items()) 1141 in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs) 1142 out_vals = ", ".join(f"{program.env[z].name}" for z in instruction.out_binders) 1143 sig = program.pprint_sig( 1144 [program.env[x].symval for x in instruction.inputs], 1145 [program.env[y].symval for y in instruction.out_binders], 1146 ) 1147 line = f"""{out_vals} = slope.{instruction.op.name}({in_vals}{", " if (param_vals and in_vals) else ""}{param_vals}) # {sig}""" 1148 body_code_lines += [indent(line, program.indent_amount)] 1149 1150 fn_args_str = ", ".join([f"{i.name}" for i in in_binders_vars]) 1151 # fn_static_args_str = ", ".join([f"{a}={a_val}" for a, a_val in program.static_args]) 1152 out_vars = [program.env[o] for o in program.outs] 1153 fn_sig = program.pprint_sig( 1154 [i.symval for i in in_binders_vars], 1155 [o.symval for o in out_vars], 1156 ) 1157 
head_code_line = [f"def {program.name}({fn_args_str}): # {fn_sig}"] 1158 out_str = ", ".join([f"{o.name}" for o in out_vars]) 1159 tail_code_line = [indent(f"return {out_str}", program.indent_amount)] 1160 code_lines = head_code_line + body_code_lines + tail_code_line + ["\n"] 1161 1162 fn_defs[program.name] = code_lines 1163 return fn_defs 1164 1165 @staticmethod 1166 def prune_instructions(instructions, outs): 1167 graph = dict() 1168 for instruction in instructions: 1169 parent_nodes, child_nodes = instruction.out_binders, instruction.inputs 1170 for parent in parent_nodes: 1171 if parent not in graph: 1172 graph[parent] = set() 1173 for child in child_nodes: 1174 graph[parent].add(child) 1175 visited_from_terminal = set() 1176 1177 def dfs(node, visited): 1178 visited.add(node) 1179 if node in graph: 1180 for neighbor in graph[node]: 1181 if neighbor not in visited: 1182 dfs(neighbor, visited) 1183 1184 for terminal_node in outs: 1185 dfs(terminal_node, visited_from_terminal) 1186 unreachable_nodes = set(graph.keys()) - visited_from_terminal 1187 1188 instructions_to_prune = [] 1189 for instruction in instructions: 1190 parent_nodes, child_nodes = instruction.out_binders, instruction.inputs 1191 if any(node in unreachable_nodes for node in parent_nodes) or any(node in unreachable_nodes for node in child_nodes): 1192 instructions_to_prune += [instruction] 1193 new_instructions = [inst for inst in instructions if inst not in instructions_to_prune] 1194 if backend.LOG_PROGRAM: 1195 LI = len(instructions) 1196 LNI = len(new_instructions) 1197 DIFF = LI - LNI 1198 UN = len(unreachable_nodes) 1199 dblog(f"Before: {LI}\tAfter: {LNI}\tDiff vs Unreachables: {DIFF} == {UN} = {DIFF==UN}") 1200 return new_instructions 1201 1202 1203class ProgramType(NamedTuple): 1204 in_types: Tuple[SymbolicTensor] 1205 out_types: Tuple[SymbolicTensor] 1206 1207 def __repr__(self): 1208 in_types = ", ".join(symval.str_short() for symval in self.in_types) 1209 out_types = ", ".join(symval.str_short() for symval in self.out_types) 1210 return f"({in_types}) -> ({out_types})" 1211 1212 1213# ================================= 1214# Tracer and Trace 1215# ================================= 1216 1217 1218class Empty: 1219 pass 1220 1221 1222empty = Empty() 1223 1224 1225class Store: 1226 val = empty 1227 1228 def set_value(self, val): 1229 assert self.val is empty 1230 self.val = val 1231 1232 def __call__(self): 1233 return self.val 1234 1235 1236class NodeType(NamedTuple): 1237 name: str 1238 flatten: Callable 1239 unflatten: Callable 1240 1241 1242class TreeDef(NamedTuple): 1243 node_type: NodeType 1244 node_metadata: Hashable 1245 child_treedefs: Tuple["TreeDef", ...] 
1246 1247 def __repr__(self): 1248 ret = self.tree_repr(self) 1249 return ret 1250 1251 def tree_repr(self, tree, indent=" ", prefix="", last=True): 1252 ret = "" 1253 1254 def _tree_repr(tree, indent, prefix, last): 1255 nonlocal ret 1256 if isinstance(tree, TreeDef): 1257 ret += f'{prefix} {("└─" if last else "├─")} {tree.node_type.name}\n' 1258 for i, item in enumerate(tree.child_treedefs): 1259 new_prefix = prefix + (indent if not last else " ") 1260 new_last = i == len(tree.child_treedefs) - 1 1261 _tree_repr(item, indent, new_prefix, new_last) 1262 else: 1263 ret += f'{prefix} {("└─" if last else "├─")} {tree}\n' 1264 1265 _tree_repr(tree, indent=" ", prefix="", last=True) 1266 return ret 1267 1268 @property 1269 def num_leaves(self): 1270 def get_num_leaves(x): 1271 if isinstance(x, Leaf): 1272 return 1 1273 else: 1274 return sum(get_num_leaves(sub_x) for sub_x in x.child_treedefs) 1275 1276 return sum(get_num_leaves(x) for x in self.child_treedefs) 1277 1278 1279class Leaf: 1280 def __init__(self, val): 1281 if hasattr(val, "shape"): 1282 val = SymbolicTensor.like(val) 1283 self.val = val 1284 1285 def __repr__(self): 1286 ret = self.val.str_short() if isinstance(self.val, SymbolicTensor) else repr(self.val) 1287 return f"<Leaf: {ret}>" 1288 1289 def __hash__(self): 1290 return hash(self.val) 1291 1292 def __eq__(self, other): 1293 return True # make TreeDef __eq__ don't care Leaf 1294 # if isinstance(other, Leaf): # TODO: test above assumption 1295 # return self.val == other.val 1296 1297 1298# ================================= 1299# jit operator 1300# ================================= 1301 1302 1303class JitOutput: 1304 def __init__(self, program: Program, codegen_output: CodegenOutput, fn, code: str, consts: List[Any]): 1305 super().__init__() 1306 self.program = program 1307 self.code = code 1308 self.codegen_output = codegen_output 1309 self.fn: Callable = fn 1310 self.consts = consts 1311 1312 def __call__(self, *args, **params): 1313 args, in_tree = tree_flatten(args) 1314 args = tree_map(lambda a: a.val if isinstance(a, Tensor) else a, args) 1315 try: 1316 outs = self.fn(*args, **params) 1317 if not isinstance(outs, tuple): # TODO: IREE FunctionInvoker destructure 1-tuple, need to undo 1318 outs = (outs,) 1319 except Exception as e: 1320 dblog(self.code, enable=backend.LOG_JIT) 1321 raise 1322 return [backend.tensor(TensorBuffer(o)) for o in outs] 1323 1324 1325class JitOp(MetaOperator): 1326 def meta_impl(self, *args, program: Program, **_): 1327 hashed_program = Hashed(program) 1328 num_consts = program.num_consts 1329 consts, args = args[:num_consts], args[num_consts:] 1330 hashed_consts = tuple(map(Hashed, consts)) 1331 jit_output = backend.jit_program(hashed_program, hashed_consts) 1332 ret = jit_output(*consts, *args) 1333 return ret 1334 1335 def reorg_args(self, args, params): 1336 return args, params 1337 1338 def typecheck(self, *in_types, program: Program): 1339 program_type = typecheck_program(program) 1340 if not all(t1 == t2 for t1, t2 in zip(program_type.in_types, in_types)): 1341 ret = "Type mismatch program.in_types vs in_types:\n" 1342 for i, j in zip(program_type.in_types, in_types): 1343 ret += f"{i}, {j}, {i == j}" 1344 raise TypeError(ret) 1345 return program_type.out_types 1346 1347 def vmap(self, dim_size, vals_in, dims_in, program: Program): 1348 program, consts = vmap_program(program, dim_size, tuple(dims_in)) 1349 outs = self(*consts, *vals_in, program=program) 1350 if not isinstance(outs, tuple): 1351 outs = (outs,) 1352 return outs, [0] * 
len(outs) 1353 1354 def jvp(self, primals, tangents, *, program): 1355 new_program, new_consts = jvp_program(program) 1356 outs = bind( 1357 self, 1358 *new_consts, 1359 *primals, 1360 *tangents, 1361 program=new_program, 1362 ) 1363 n = len(outs) // 2 1364 primals_out, tangents_out = outs[:n], outs[n:] 1365 return primals_out, tangents_out 1366 1367 def T(self, cotangents, *invals, program): 1368 undef_primals = [isinstance(x, UndefinedPrimal) for x in invals] 1369 transposed_program, new_consts = transpose_program(program, tuple(undef_primals)) 1370 1371 residuals, _ = partition_list(undef_primals, invals) 1372 outs = bind( 1373 self, 1374 *new_consts, 1375 *residuals, 1376 *cotangents, 1377 program=transposed_program, 1378 ) 1379 outs = iter(outs) 1380 1381 return [next(outs) if undef else None for undef in undef_primals] 1382 1383 def partial_run(self, trace, tracers, *, program): 1384 in_unknowns = [not t.pval.is_known for t in tracers] 1385 program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns) 1386 known_tracers, unknown_tracers = partition_list(in_unknowns, tracers) 1387 known_vals = [t.pval.const for t in known_tracers] 1388 outs1_res = bind(backend.jit_op, *known_vals, program=program1) 1389 outs1, res = split_list(outs1_res, len(program1.outs) - num_res) 1390 res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res] 1391 outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs] 1392 instruction = InstructionDraft( 1393 self, 1394 res_tracers + unknown_tracers, 1395 dict(program=program2), 1396 [v.symval for v in program2.outs], 1397 list_map(weakref.ref, outs2), 1398 ) 1399 for t in outs2: 1400 t.draft = instruction 1401 1402 return merge_lists(out_unknowns, outs1, outs2) 1403 1404 def partial_run_instruction(self, unks_in, instruction) -> Tuple[Instruction, Instruction, List[bool], List[Var]]: 1405 program = instruction.params["program"] 1406 program1, program2, out_unknowns, num_res = partial_run_program(program, unks_in) 1407 ins1, ins2 = partition_list(unks_in, instruction.inputs) 1408 out_binders1, out_binders2 = partition_list(out_unknowns, instruction.out_binders) 1409 res = [Var(v.symval) for v in program2.in_binders[:num_res]] 1410 instruction1 = Instruction(self, ins1, dict(program=program1), out_binders1 + res) 1411 instruction2 = Instruction(self, res + ins2, dict(program=program2), out_binders2) 1412 return instruction1, instruction2, out_unknowns, res 1413 1414 1415class MainTrace(NamedTuple): 1416 level: int 1417 trace_type: Type["Trace"] 1418 global_data: Optional[Any] 1419 1420 1421class Trace: 1422 main: MainTrace 1423 1424 def __init__(self, main: MainTrace) -> None: 1425 self.main = main 1426 1427 def pure(self, val): 1428 raise NotImplementedError 1429 1430 def run_op(self, op, tracers, params): 1431 raise NotImplementedError 1432 1433 1434class RunTrace(Trace): 1435 pure = lambda self, x: x 1436 1437 def run_op(self, op: Operator, args, params): 1438 if isinstance(op, MetaOperator): 1439 args, params = op.reorg_args(args, params) 1440 args, params = op.args_fixer(*args, **params) 1441 ret = op.meta_impl(*args, **params) 1442 else: 1443 fn = self.get_fn(op, *tuple(SymbolicTensor.like(a) for a in args), **params) 1444 # with Timing(f"RUN {op}"):ret = jit( 1445 ret = jit( 1446 fn, 1447 static_argnames=("params",), 1448 name=jit.get_jit_name(args, params, op.name), 1449 )(*args, **params) 1450 1451 return ret 1452 1453 @staticmethod 1454 @lru_cache_verbose() 1455 def 
get_fn(op, *symval_args, **params): 1456 def fn(*args, **params): 1457 return [op(*args, **params)] 1458 1459 return fn 1460 1461 1462class SymbolicRunTrace(Trace): 1463 # pure = lambda self, x: x 1464 def pure(self, val: Any) -> SymbolicTensor: 1465 return val.symval 1466 1467 def run_op(self, op, tracers, params): 1468 symvals_in = tree_map(lambda x: x.symval, tracers) 1469 symvals_out = op.typecheck(*symvals_in, **params) 1470 return symvals_out 1471 1472 1473class TraceTensor(Tensor): 1474 PYTHON_TYPES = { 1475 bool, 1476 int, 1477 float, 1478 } 1479 _trace: "Trace" 1480 1481 def __init__(self, *args, **kwargs): 1482 raise NotImplementedError 1483 1484 symval = property(lambda self: get_symval(self.val)) 1485 dtype = property(lambda self: self.symval.dtype) 1486 shape = property(lambda self: self.symval.shape) 1487 device = property(lambda self: self.symval.device) 1488 1489 @property 1490 def val(self): 1491 raise NotImplementedError 1492 1493 def __str__(self): 1494 return repr(self) 1495 1496 def full_lower(self): 1497 return self 1498 1499 @property 1500 def ndim(self): 1501 return len(self.shape) 1502 1503 def __repr__(self): 1504 return f"{self.__class__.__name__}({repr(self.symval)})" 1505 1506 1507class VMapTraceTensor(TraceTensor): 1508 def __init__(self, trace, val, vmap_dim): 1509 self._trace = trace 1510 self._val = val 1511 self.vmap_dim = vmap_dim 1512 1513 @property 1514 def val(self): 1515 return self._val 1516 1517 @property 1518 def symval(self): 1519 symval = get_symval(self.val) 1520 if self.vmap_dim is None: 1521 return symval 1522 else: 1523 shape = list(symval.shape) 1524 del shape[self.vmap_dim] 1525 return symval.like(shape=tuple(shape)) 1526 1527 def full_lower(self): 1528 if self.vmap_dim is None: 1529 return full_lower(self.val) 1530 else: 1531 return self 1532 1533 1534class VMapTrace(Trace): 1535 pure = lambda self, val: VMapTraceTensor(self, val, None) 1536 1537 @property 1538 def dim_size(self): 1539 return self.main.global_data 1540 1541 def run_op(self, op, tracers, params): 1542 vals_in, bdims_in = unzip2((t.val, t.vmap_dim) for t in tracers) 1543 val_outs, bdim_outs = op.vmap(self.dim_size, vals_in, bdims_in, **params) 1544 return [VMapTraceTensor(self, x, bd) for x, bd in list_zip(val_outs, bdim_outs)] 1545 1546 @staticmethod 1547 def move_vmap_dim(x, dim_size, src: int, dst: int): 1548 if src is None: # unsqueeze and expand 1549 target_shape = list(x.shape) 1550 target_shape.insert(dst, dim_size) 1551 unsqueeze_shape = [1 if d == dst else target_shape[d] for d in range(len(target_shape))] 1552 x = x.reshape(tuple(unsqueeze_shape)) 1553 x = x.expand(tuple(target_shape)) 1554 return x 1555 elif src == dst: 1556 return x 1557 else: 1558 perm = [i for i in range(len(x.shape)) if i != src] 1559 perm.insert(dst, src) 1560 return x.permute(tuple(perm)) 1561 1562 1563class JVPTraceTensor(TraceTensor): 1564 def __init__(self, trace, primal, tangent): 1565 self._trace = trace 1566 self.primal = primal 1567 self.tangent = tangent 1568 1569 @property 1570 def symval(self): 1571 return get_symval(self.primal) 1572 1573 @property 1574 def val(self): 1575 return self.primal 1576 1577 @property 1578 def dtype(self): 1579 return self.primal.dtype 1580 1581 1582class JVPTrace(Trace): 1583 def pure(self, val): 1584 if isinstance(val, PartialRunTrace): 1585 val = val.pval.const 1586 return JVPTraceTensor(self, val, backend.zeros_like(val)) 1587 1588 def run_op(self, op, tracers, params): 1589 primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) 
1590 primals_out, tangents_out = op.jvp(primals_in, tangents_in, **params) 1591 return [JVPTraceTensor(self, x, t) for x, t in list_zip(primals_out, tangents_out)] 1592 1593 1594class ProgramTraceTensor(TraceTensor): 1595 __slots__ = ["symval"] 1596 symval: SymbolicTensor 1597 1598 def __init__(self, trace, symval): 1599 self._trace = trace 1600 self.symval = symval 1601 1602 1603class ProgramTrace(Trace): 1604 @property 1605 def builder(self): 1606 return self.main.global_data 1607 1608 def new_arg(self, symval) -> ProgramTraceTensor: 1609 symval = SymbolicTensor.like(symval) 1610 tracer = self.builder.new_tracer(self, symval) 1611 self.builder.tracer_to_var[id(tracer)] = Var(symval) 1612 1613 return tracer 1614 1615 def pure(self, val: Any) -> ProgramTraceTensor: 1616 # get_or_make_const_tracer 1617 tracer = self.builder.const_tracers.get(id(val)) 1618 if tracer is None: 1619 tracer = self.builder.new_tracer(self, get_symval(val)) 1620 self.builder.add_const(tracer, val) 1621 # print(self.builder.const_tracers) 1622 return tracer 1623 1624 def run_op(self, op, tracers, params): 1625 symvals_in = tree_map(lambda x: x.symval, tracers) 1626 symvals_out = op.typecheck(*symvals_in, **params) 1627 1628 out_tracers = [self.builder.new_tracer(self, a) for a in symvals_out] 1629 inputs = [self.builder.getvar(t) for t in tracers] 1630 outvars = [self.builder.add_var(t) for t in out_tracers] 1631 1632 self.builder.add_instruction(Instruction(op, inputs, params, outvars)) 1633 return out_tracers 1634 1635 1636class ProgramBuilder: 1637 instructions: List[Instruction] 1638 tracer_to_var: Dict[int, Var] 1639 const_tracers: Dict[int, TraceTensor] 1640 constvals: Dict[Var, Any] 1641 tracers: List[ProgramTraceTensor] 1642 1643 def __init__(self): 1644 self.instructions = [] 1645 self.tracer_to_var = {} 1646 self.const_tracers = {} 1647 self.constvals = {} 1648 self.tracers = [] 1649 1650 def new_tracer(self, trace: ProgramTrace, symval: SymbolicTensor) -> ProgramTraceTensor: 1651 tracer = ProgramTraceTensor(trace, symval) 1652 self.tracers += [tracer] 1653 return tracer 1654 1655 def add_instruction(self, instruction: Instruction) -> None: 1656 self.instructions += [instruction] 1657 1658 def add_var(self, tracer: ProgramTraceTensor) -> Var: 1659 assert id(tracer) not in self.tracer_to_var 1660 var = self.tracer_to_var[id(tracer)] = Var(tracer.symval) 1661 return var 1662 1663 def getvar(self, tracer: ProgramTraceTensor) -> Var: 1664 var = self.tracer_to_var.get(id(tracer)) 1665 assert var is not None 1666 return var 1667 1668 def add_const(self, tracer: ProgramTraceTensor, val: Any) -> Var: 1669 var = self.add_var(tracer) 1670 self.const_tracers[id(val)] = tracer 1671 self.constvals[var] = val 1672 return var 1673 1674 def build(self, in_tracers: Any, out_tracers: Any, static_args, name) -> Tuple[Program, List[Any]]: 1675 constvars, constvals = unzip2(self.constvals.items()) 1676 t2v = lambda t: self.tracer_to_var[id(t)] 1677 in_binders = constvars + [t2v(t) for t in in_tracers] 1678 out_vars = [t2v(t) for t in out_tracers] 1679 program = Program( 1680 in_binders, 1681 self.instructions, 1682 out_vars, 1683 len(constvals), 1684 static_args, 1685 name, 1686 ) 1687 typecheck_program(program) 1688 program, constvals = self._inline_literals(program, constvals) 1689 typecheck_program(program) 1690 # dblog(program, enable=backend.LOG_PROGRAM) 1691 return program, constvals 1692 1693 def _inline_literals(self, program: Program, consts: List[Any]) -> Tuple[Program, List[Any]]: 1694 const_binders, other_binders 
= split_list(program.in_binders, len(consts)) 1695 scalars = [type(x) in TraceTensor.PYTHON_TYPES and not get_symval(x).shape for x in consts] 1696 new_const_binders, lit_binders = partition_list(scalars, const_binders) 1697 new_consts, lit_vals = partition_list(scalars, consts) 1698 literals = dict(list_zip(lit_binders, list_map(Lit, lit_vals))) 1699 new_outs = [literals.get(x, x) for x in program.outs] 1700 new_instructions = [ 1701 Instruction( 1702 instruction.op, 1703 [literals.get(x, x) for x in instruction.inputs], 1704 instruction.params, 1705 instruction.out_binders, 1706 ) 1707 for instruction in program.instructions 1708 ] 1709 new_program = Program( 1710 new_const_binders + other_binders, 1711 new_instructions, 1712 new_outs, 1713 len(new_consts), 1714 program.static_args, 1715 program.name, 1716 ) 1717 return new_program, tuple(new_consts) 1718 1719 def get_current_scope_info(self): 1720 current_frame = inspect.currentframe() 1721 current_function_name = current_frame.f_code.co_name 1722 current_module_name = inspect.getmodulename(current_frame.f_code.co_filename) 1723 current_class_name = None 1724 for frame_info in inspect.getouterframes(current_frame): 1725 print(frame_info) 1726 frame_locals = frame_info.frame.f_locals 1727 print(frame_locals) 1728 if "self" in frame_locals: 1729 current_class_name = frame_locals["self"].__class__.__name__ 1730 break 1731 return { 1732 "Function": current_function_name, 1733 "Module": current_module_name, 1734 "Class": current_class_name, 1735 } 1736 1737 1738class UndefinedPrimal(NamedTuple): 1739 symval: SymbolicTensor 1740 1741 @property 1742 def shape(self): 1743 return self.symval.shape 1744 1745 @property 1746 def dtype(self): 1747 return self.symval.dtype 1748 1749 @property 1750 def device(self): 1751 return self.symval.device 1752 1753 @property 1754 def ndim(self): 1755 return self.symval.ndim 1756 1757 def __repr__(self): 1758 return f"<UndefinedPrimal: symval={self.symval}>" 1759 1760 str_short = __repr__ 1761 1762 1763class PartialValue(NamedTuple): 1764 symval: SymbolicTensor 1765 const: Optional[Any] 1766 1767 is_known = property(lambda self: self.const is not None) 1768 is_unknown = property(lambda self: self.const is None) 1769 1770 1771class LambdaBindingDraft(NamedTuple): 1772 pass 1773 1774 1775class ConstDraft(NamedTuple): 1776 val: Any 1777 1778 1779class InstructionDraft(NamedTuple): 1780 prim: Operator 1781 tracers_in: List["PartialRunTraceTensor"] 1782 params: Dict[str, Any] 1783 symvals_out: List[SymbolicTensor] 1784 tracer_refs_out: List[weakref.ReferenceType["PartialRunTraceTensor"]] 1785 1786 1787ProgramDraft = Union[LambdaBindingDraft, ConstDraft, InstructionDraft] 1788 1789 1790class PartialRunTraceTensor(TraceTensor): 1791 def __init__(self, trace, pval, draft): 1792 self._trace = trace 1793 self.pval = pval 1794 self.draft = draft 1795 1796 symval = property(lambda self: self.pval.symval) 1797 val = property(lambda self: self.pval.const) 1798 1799 def full_lower(self): 1800 if self.pval.is_known: 1801 return full_lower(self.pval.const) 1802 return self 1803 1804 1805class PartialRunTrace(Trace): 1806 def new_arg(self, pval: PartialValue) -> Any: 1807 return PartialRunTraceTensor(self, pval, LambdaBindingDraft()) 1808 1809 def pure(self, val: Any) -> PartialRunTraceTensor: 1810 return PartialRunTraceTensor(self, make_known_pval(val), None) 1811 1812 def instantiate_const(self, tracer: PartialRunTraceTensor) -> PartialRunTraceTensor: 1813 if tracer.pval.is_unknown: 1814 return tracer 1815 else: 1816 pval = 
make_unknown_pval(SymbolicTensor.like(tracer.symval)) 1817 return PartialRunTraceTensor(self, pval, ConstDraft(tracer.pval.const)) 1818 1819 def run_op(self, op, tracers, params): 1820 is_knowns = tuple(t.pval.is_known for t in tracers) 1821 1822 if all(is_knowns): 1823 return bind(op, *list_map(full_lower, tracers), **params) 1824 return op.partial_run(self, tracers, **params) 1825 1826 1827trace_stack: List[MainTrace] = [] 1828stashed_trace: Optional[MainTrace] = None 1829trace_stack += [MainTrace(0, RunTrace, None)] 1830 1831 1832class UndefBackend: 1833 def __getattr__(self, attr): 1834 raise NotImplementedError("Backend not init yet with slope.core.set_backend(backend)") 1835 1836 1837backend = UndefBackend() 1838 1839 1840def set_backend(name, where="slope.backends"): 1841 global backend 1842 backend = importlib.import_module(f"{where}.{name}").backend 1843 import slope.nn as nn 1844 1845 # backend.register_node(nn.Module, nn.Module.flatten, nn.Module.unflatten, "Module") 1846 1847 dblog(f"slope backend is {backend}", enable=backend.LOG_INIT) 1848 1849 1850def stack_str(): 1851 ret = "" 1852 for trace in trace_stack: 1853 ret += f"{trace.level}: {trace.trace_type.__name__}\t{trace.global_data=}\n" 1854 return ret 1855 1856 1857def make_known_pval(val: Any): 1858 return PartialValue(get_symval(val), val) 1859 1860 1861def make_unknown_pval(symval: SymbolicTensor): 1862 return PartialValue(symval, None) 1863 1864 1865def get_symval(x): 1866 if isinstance(x, TraceTensor): 1867 return x.symval 1868 elif type(x) in TraceTensor.PYTHON_TYPES: 1869 return backend.tensor(x) 1870 elif isinstance(x, Tensor): 1871 return x 1872 elif isinstance(x, SymbolicTensor): 1873 return x 1874 else: 1875 raise TypeError(type(x)) 1876 1877 1878def tree_flatten(x: Any) -> Any: 1879 def _tree_flatten(x_: Any) -> Tuple[Iterable, Union[TreeDef, Leaf]]: 1880 node_type = None 1881 for k in backend.node_types.keys(): 1882 if isinstance(x_, k): 1883 node_type = backend.node_types[k] 1884 1885 if node_type is not None: 1886 node_metadata, children = node_type.flatten(x_) 1887 children_flat, child_trees = unzip2(list_map(_tree_flatten, children)) 1888 children_iter = itertools.chain.from_iterable(children_flat) 1889 treedef = TreeDef(node_type, node_metadata, tuple(child_trees)) 1890 return children_iter, treedef 1891 else: 1892 return (x_,), Leaf(x_) 1893 1894 children_iter, treedef = _tree_flatten(x) 1895 return tuple(children_iter), treedef 1896 1897 1898def tree_unflatten(treedef: TreeDef, xs: Tuple[Any]) -> Any: 1899 def _tree_unflatten(treedef_: TreeDef, xs_: Iterator) -> Any: 1900 if isinstance(treedef_, Leaf): 1901 dblog(f" tree leaf found: {xs_}\n", enable=backend.LOG_TREE) 1902 return next(xs_) 1903 else: 1904 dblog(f" now\n {treedef_}", enable=backend.LOG_TREE) 1905 children = (_tree_unflatten(t, xs_) for t in treedef_.child_treedefs) 1906 dblog(f"{children=}\n", enable=backend.LOG_TREE) 1907 return treedef_.node_type.unflatten(treedef_.node_metadata, children) 1908 1909 dblog(f"unflattening {treedef}", enable=backend.LOG_TREE) 1910 return _tree_unflatten(treedef, iter(xs)) 1911 # with Timing(f"\nTREE:\n{treedef}"): 1912 # ret = _tree_unflatten(treedef, iter(xs)) 1913 # return ret 1914 1915 1916def tree_transpose( 1917 outer_treedef: TreeDef, 1918 inner_treedef: TreeDef, 1919 tree_to_transpose: Any, 1920) -> Any: 1921 flat, treedef = tree_flatten(tree_to_transpose) 1922 inner_size = inner_treedef.num_leaves 1923 outer_size = outer_treedef.num_leaves 1924 if treedef.num_leaves != (inner_size * outer_size): 
1925 raise TypeError 1926 iter_flat = iter(flat) 1927 lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)] 1928 permuted_lol = zip(*lol) 1929 subtrees = map(partial(tree_unflatten, outer_treedef), permuted_lol) 1930 return tree_unflatten(inner_treedef, subtrees) 1931 1932 1933def flatten_fn(f, in_tree, *, has_aux=False): 1934 store = Store() 1935 1936 def flat_fn(*args_flat, **params): 1937 tree_args = tree_unflatten(in_tree, args_flat) 1938 out = f(*tree_args, **params) 1939 if has_aux: 1940 out, aux = out 1941 out_flat, out_tree = tree_flatten(out) 1942 store.set_value(out_tree) 1943 return (out_flat, aux) if has_aux else out_flat 1944 1945 return flat_fn, store 1946 1947 1948def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any: 1949 leaves, treedef = tree_flatten(tree) 1950 if len(rest) == 0: 1951 out_tree_flat = tuple(f(leaf) for leaf in leaves) 1952 out_tree = tree_unflatten(treedef, out_tree_flat) 1953 else: 1954 all_leaves = [leaves] 1955 for t in rest: 1956 t_leaves, t_treedef = tree_flatten(t) 1957 assert t_treedef == treedef 1958 all_leaves += [t_leaves] 1959 1960 out_tree_flat = tuple(f(*xs) for xs in zip(*all_leaves)) 1961 out_tree = tree_unflatten(treedef, out_tree_flat) 1962 ret = out_tree 1963 if out_leaf: 1964 ret = (ret, tree_flatten(out_tree_flat[0])) 1965 return ret 1966 1967 1968@contextmanager 1969def new_main_trace(trace_type: Type["Trace"], global_data=None): 1970 global trace_stack 1971 level = len(trace_stack) 1972 main = MainTrace(level, trace_type, global_data) 1973 trace_stack += [main] 1974 1975 try: 1976 yield main 1977 finally: 1978 trace_stack.pop() 1979 1980 1981def bind(op, *args, **params): 1982 top_trace = find_top_trace(args) 1983 tracers = tuple([full_raise(top_trace, arg) for arg in args]) 1984 outs = top_trace.run_op(op, tracers, params) 1985 lowered = tuple([full_lower(out) for out in outs]) 1986 return lowered 1987 1988 1989def find_top_trace(xs) -> Trace: 1990 arrs = [] 1991 1992 def get_arr_from_seq(seq): 1993 nonlocal arrs 1994 for x in seq: 1995 if type(x) in (tuple, list): 1996 get_arr_from_seq(x) 1997 elif isinstance(x, TraceTensor): 1998 arrs += [x] 1999 2000 get_arr_from_seq(xs) 2001 arrs = tuple(arrs) 2002 top_main = max( 2003 (x._trace.main for x in arrs), 2004 default=trace_stack[0], 2005 key=operator_py.attrgetter("level"), 2006 ) 2007 if stashed_trace and stashed_trace.level > top_main.level: 2008 top_main = stashed_trace 2009 return top_main.trace_type(top_main) 2010 2011 2012def full_raise(trace: Trace, val: Any) -> TraceTensor: 2013 if not isinstance(val, TraceTensor): 2014 return trace.pure(val) 2015 level = trace.main.level 2016 if val._trace.main is trace.main: 2017 return val 2018 elif val._trace.main.level < level: 2019 return trace.pure(val) 2020 elif val._trace.main.level > level: 2021 raise Exception(f"Can't lift level {val._trace.main.level} to {level}.") 2022 else: 2023 raise Exception(f"Different traces at same level: {val._trace}, {trace}.") 2024 2025 2026def full_lower(val: Any): 2027 if isinstance(val, TraceTensor): 2028 return val.full_lower() 2029 elif type(val) in (list, tuple): 2030 return tuple(full_lower(v) for v in val) 2031 else: 2032 return val 2033 2034 2035def typecheck_program(program: Program) -> ProgramType: 2036 env: Set[Var] = set() 2037 2038 for v in program.in_binders: 2039 if v in env: 2040 raise TypeError 2041 env.add(v) 2042 2043 for instruction in program.instructions: 2044 in_types = [typecheck_atom(env, x) for x in instruction.inputs] 2045 out_types 
= instruction.op.typecheck(*in_types, **instruction.params) 2046 for out_binder, out_type in list_zip(instruction.out_binders, out_types): 2047 if not out_type == out_binder.symval: 2048 raise TypeError 2049 for out_binder in instruction.out_binders: 2050 if out_binder in env: 2051 raise TypeError 2052 env.add(out_binder) 2053 2054 in_types = [v.symval for v in program.in_binders] 2055 out_types = [typecheck_atom(env, x) for x in program.outs] 2056 return ProgramType(tuple(in_types), tuple(out_types)) 2057 2058 2059def typecheck_atom(env: Set[Var], x: Atom) -> SymbolicTensor: 2060 if isinstance(x, Var): 2061 if x not in env: 2062 raise TypeError("unbound variable") 2063 return x.symval 2064 elif isinstance(x, Lit): 2065 return get_symval(x.val) 2066 else: 2067 assert False 2068 2069 2070def run_program(program: Program, args: List[Any]) -> List[Any]: 2071 env: Dict[Var, Any] = {} 2072 2073 def read(x: Atom) -> Any: 2074 return env[x] if type(x) is Var else x.val 2075 2076 def write(v: Var, val: Any) -> None: 2077 assert v not in env # single-assignment 2078 env[v] = val 2079 2080 list_map(write, program.in_binders, args) 2081 for instruction in program.instructions: 2082 in_vals = list_map(read, instruction.inputs) 2083 outs = bind(instruction.op, *in_vals, **instruction.params) 2084 list_map(write, instruction.out_binders, outs) 2085 return list_map(read, program.outs) 2086 2087 2088def program_as_fun(program: Program): 2089 return lambda *args: run_program(program, args) 2090 2091 2092def vmap_flat(f, in_dim, out_dim, dim_size, *args): 2093 if dim_size is None: 2094 dims = set([x.shape[d] for x, d in list_zip(args, in_dim) if d is not None]) 2095 assert len(dims) == 1 2096 (dim_size,) = dims 2097 with new_main_trace(VMapTrace, dim_size) as main: 2098 trace = VMapTrace(main) 2099 tracers_in = [VMapTraceTensor(trace, x, dim) if dim is not None else x for x, dim in list_zip(args, in_dim)] 2100 outs = f(*tracers_in) 2101 tracers_out = [full_raise(trace, out) for out in outs] 2102 vals_out, y_vmap_dims = unzip2((t.val, t.vmap_dim) for t in tracers_out) 2103 ret = [VMapTrace.move_vmap_dim(val_out, dim_size, bdim, out_dim) for val_out, bdim, out_dim in zip(vals_out, y_vmap_dims, out_dim)] 2104 return ret 2105 2106 2107def vmap(f, in_dim=0, out_dim=0, dim_size=None): 2108 def batched_f(*args): 2109 nonlocal in_dim, out_dim, dim_size 2110 args_flat, in_tree = tree_flatten(args) 2111 in_dim = (in_dim,) * len(args) if isinstance(in_dim, int) else in_dim 2112 out_dim = (out_dim,) * len(args) if isinstance(out_dim, int) else out_dim 2113 in_dim_flat, in_dim_tree = tree_flatten(in_dim) 2114 out_dim_flat, out_dim_tree = tree_flatten(out_dim) 2115 if not (in_tree == in_dim_tree == out_dim_tree): 2116 raise TypeError(f"\n{in_tree}\n!=\n{in_dim_tree}!=\n{out_dim_tree}") 2117 f_flat, out_tree_store = flatten_fn(f, in_tree) 2118 # if len(args_flat) > len(in_dim_flat): 2119 # in_dim_flat = (in_dim[0],) * len(args_flat) 2120 outs_flat = vmap_flat(f_flat, in_dim_flat, out_dim_flat, dim_size, *args_flat) 2121 return tree_unflatten(out_tree_store(), outs_flat) 2122 2123 return batched_f 2124 2125 2126def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args): 2127 with new_main_trace(JVPTrace, global_data) as main: 2128 trace = JVPTrace(main) 2129 tracers_in = [JVPTraceTensor(trace, x, t) for x, t in list_zip(primals, tangents)] 2130 jvp_flat_ret = f(*tracers_in, **static_args) 2131 if has_aux: 2132 (outs, aux) = jvp_flat_ret 2133 # aux_ = aux 2134 aux = tree_map(lambda x: x.primal, aux) 2135 # 
aux = tree_map(lambda x: x.full_lower(), aux) 2136 # 2137 else: 2138 outs = jvp_flat_ret 2139 tracers_out = [full_raise(trace, out) for out in outs] 2140 primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out) 2141 return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out) 2142 2143 2144def jvp(f, primals, tangents, *, has_aux=False, global_data=None, **static_args): 2145 primals_flat, in_tree = tree_flatten(primals) 2146 tangents_flat, in_tree2 = tree_flatten(tangents) 2147 for p, t in zip(primals_flat, tangents_flat): 2148 assert p.shape == t.shape, f"{p.shape=} != {t.shape=}" 2149 assert p.dtype == t.dtype, f"{p.dtype=} != {t.dtype=}" 2150 assert p.device == t.device, f"{p.device=} != {t.device=}" 2151 if in_tree != in_tree2: 2152 raise TypeError 2153 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2154 jvp_ret = jvp_flat( 2155 f, 2156 primals_flat, 2157 tangents_flat, 2158 has_aux=has_aux, 2159 global_data=global_data, 2160 **static_args, 2161 ) 2162 if has_aux: 2163 (primals_out_flat, tangents_out_flat), aux = jvp_ret 2164 else: 2165 (primals_out_flat, tangents_out_flat) = jvp_ret 2166 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2167 tangents_out = tree_unflatten(out_tree_store(), tangents_out_flat) 2168 return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out) 2169 2170 2171def jacfwd(f, argnums=0, has_aux=False): 2172 def jvp_fn(x): 2173 return jvp(f, x, (backend.eye(len(x)),), has_aux=has_aux) 2174 return vmap(jvp_fn, in_dim=argnums) 2175 2176def jacrev(f, argnums=0, has_aux=False): 2177 def grad_f(x): 2178 return grad(lambda x: f(x) @ backend.eye(f(x).shape[0]), argnums=argnums, has_aux=has_aux)(x) 2179 return vmap(grad_f) 2180 2181## arange version 2182# def jacrev(f, x): 2183# def grad_f_i(x, i): 2184# return grad(lambda x: f(x)[i])(x) 2185# return vmap(lambda i: grad_i(x, i))(backend.arange(f(x).shape[0])) 2186 2187def hessian(fn, argnums=0, has_aux=False): 2188 return jacrev(jacrev(fn, argnums=argnums, has_aux=has_aux)) 2189 2190@contextmanager 2191def stash_trace(main: MainTrace): 2192 global stashed_trace 2193 prev_stashed_trace, stashed_trace = stashed_trace, main 2194 try: 2195 yield 2196 finally: 2197 stashed_trace = prev_stashed_trace 2198 2199 2200@contextmanager 2201def symbolic_run(): 2202 global trace_stack 2203 level = len(trace_stack) 2204 main = MainTrace(level, SymbolicRunTrace, global_data=None) 2205 trace_stack += [main] 2206 global stashed_trace 2207 prev_stashed_trace, stashed_trace = stashed_trace, main 2208 try: 2209 yield 2210 finally: 2211 stashed_trace = prev_stashed_trace 2212 trace_stack.pop() 2213 2214 2215@lru_cache_verbose() 2216def make_program(f: Callable, *symvals_in: SymbolicTensor, static_args, name) -> Tuple[Program, List[Any], TreeDef]: 2217 symvals_in, in_tree = tree_flatten(symvals_in) 2218 f, out_tree_store = flatten_fn(f, in_tree) 2219 builder = ProgramBuilder() 2220 with new_main_trace(ProgramTrace, builder) as main: 2221 with stash_trace(main): 2222 trace = ProgramTrace(main) 2223 tracers_in = [trace.new_arg(symval) for symval in symvals_in] 2224 outs = f(*tracers_in, **{k: v for k, v in static_args}) 2225 tracers_out = [full_raise(trace, out) if isinstance(out, ProgramTraceTensor) else out.val for out in outs] 2226 program, consts = builder.build(tracers_in, tracers_out, static_args, name) 2227 2228 return program, consts, out_tree_store() 2229 2230 2231@lru_cache_verbose() 2232def vmap_program(program: Program, dim_size, 
dims_in) -> tuple[Program, list[Any]]: 2233 def unmapped_symval(axis_size: int, batch_dim, symval: SymbolicTensor) -> SymbolicTensor: 2234 if batch_dim is None: 2235 return symval 2236 else: 2237 shape = list(symval.shape) 2238 shape.insert(batch_dim, axis_size) 2239 return symval.like(shape=tuple(shape)) 2240 2241 vmap_traceable = vmap(program_as_fun(program), tuple(dims_in)) 2242 in_symvals = [unmapped_symval(dim_size, d, v.symval) for v, d in zip(program.in_binders, dims_in)] 2243 program, consts, _ = make_program( 2244 vmap_traceable, 2245 *in_symvals, 2246 static_args=program.static_args, 2247 name=f"vmap_{program.name}", 2248 ) 2249 return program, consts 2250 2251 2252@lru_cache_verbose() 2253def jvp_program(program: Program) -> Tuple[Program, List[Any]]: 2254 def jvp_traceable(*primals_and_tangents): 2255 n = len(primals_and_tangents) // 2 2256 primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:] 2257 return jvp(program_as_fun(program), primals, tangents) 2258 2259 in_symvals = tree_map(lambda v: v.symval, program.in_binders) 2260 new_program, new_consts, _ = make_program( 2261 jvp_traceable, 2262 *in_symvals, 2263 *in_symvals, 2264 static_args=program.static_args, 2265 name=f"{program.name}_jvp", 2266 ) 2267 return new_program, new_consts 2268 2269 2270def partial_run_flat( 2271 f: Callable, pvals_in: List["PartialValue"], has_aux, global_data=None 2272) -> Tuple[Program, List["PartialValue"], List[Any]]: 2273 with new_main_trace(PartialRunTrace, global_data) as main: 2274 trace = PartialRunTrace(main) 2275 tracers_in = [trace.new_arg(pval) for pval in pvals_in] 2276 outs = f(*tracers_in) 2277 if has_aux: 2278 outs, aux = outs 2279 tracers_out = [full_raise(trace, out) for out in outs] 2280 pvals_out = [t.pval for t in tracers_out] 2281 unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown] 2282 unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown] 2283 program, consts = tracers_to_program(unk_tracers_in, unk_tracers_out) 2284 2285 return (program, pvals_out, consts, aux) if has_aux else (program, pvals_out, consts) 2286 2287 2288def partial_run_program( 2289 program: Program, 2290 in_unknowns: List[bool], 2291 instantiate: Optional[List[bool]] = None, 2292) -> Tuple[Program, Program, List[bool], int]: 2293 env: Dict[Var, bool] = {} 2294 residuals: Set[Var] = set() 2295 2296 def read(x: Atom) -> bool: 2297 return type(x) is Var and env[x] 2298 2299 def write(unk: bool, v: Var) -> None: 2300 env[v] = unk 2301 2302 instructions1, instructions2 = [], [] 2303 list_map(write, in_unknowns, program.in_binders) 2304 2305 for instruction in program.instructions: 2306 unks_in = list_map(read, instruction.inputs) 2307 ( 2308 instruction1, 2309 instruction2, 2310 unks_out, 2311 res, 2312 ) = instruction.op.partial_run_instruction(unks_in, instruction) 2313 if instruction1 is not None: 2314 instructions1 += [instruction1] 2315 if instruction2 is not None: 2316 instructions2 += [instruction2] 2317 if res is not None: 2318 residuals.update(res) 2319 list_map(write, unks_out, instruction.out_binders) 2320 2321 out_unknowns = list_map(read, program.outs) 2322 if instantiate is not None: 2323 for v, uk, inst in zip(program.outs, out_unknowns, instantiate): 2324 if inst and not uk: 2325 if type(v) is Var: 2326 residuals.add(v) 2327 out_unknowns = list_map(operator_py.or_, out_unknowns, instantiate) 2328 2329 residuals, num_res = list(residuals), len(residuals) 2330 assert all(type(v) is Var for v in residuals), residuals 2331 2332 ins1, ins2 = 
partition_list(in_unknowns, program.in_binders) 2333 outs1, outs2 = partition_list(out_unknowns, program.outs) 2334 2335 program1 = Program( 2336 ins1, 2337 instructions1, 2338 outs1 + residuals, 2339 0, 2340 program.static_args, 2341 f"{program.name}_partial1", 2342 ) 2343 program2 = Program( 2344 residuals + ins2, 2345 instructions2, 2346 outs2, 2347 0, 2348 program.static_args, 2349 f"{program.name}_partial2", 2350 ) 2351 typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2) 2352 2353 return program1, program2, out_unknowns, num_res 2354 2355 2356def typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2): 2357 programty = typecheck_program(program) # (a1, a2) -> (b1, b2 ) 2358 program1ty = typecheck_program(program1) # a1 -> (b1, res) 2359 program2ty = typecheck_program(program2) # (res, a2) -> b2 2360 2361 a1, a2 = partition_list(in_unknowns, programty.in_types) 2362 b1, b2 = partition_list(out_unknowns, programty.out_types) 2363 b1_, res = split_list(program1ty.out_types, len(b1)) 2364 res_, a2_ = split_list(program2ty.in_types, len(res)) 2365 b2_ = program2ty.out_types 2366 2367 a1 = tuple(a1) 2368 a2, a2_ = tuple(a2), tuple(a2_) 2369 b1, b1_ = tuple(b1), tuple(b1_) 2370 b2, b2_ = tuple(b2), tuple(b2_) 2371 res, res_ = tuple(res), tuple(res_) 2372 2373 if program1ty.in_types != a1: 2374 raise TypeError 2375 if program2ty.out_types != b2: 2376 raise TypeError 2377 if b1 != b1_: 2378 raise TypeError 2379 if res != res_: 2380 raise TypeError 2381 if a2 != a2_: 2382 raise TypeError 2383 if b2 != b2_: 2384 raise TypeError 2385 2386 2387def linearize_flat(f, *primals_in, has_aux): 2388 pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in] 2389 2390 def f_jvp(*primals_tangents_in): 2391 jvp_ret = jvp(f, *split_half(primals_tangents_in), has_aux=has_aux) 2392 if has_aux: 2393 (primals_out, tangents_out), aux = jvp_ret 2394 return ((*primals_out, *tangents_out), aux) 2395 else: 2396 primals_out, tangents_out = jvp_ret 2397 return (*primals_out, *tangents_out) 2398 2399 partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux) 2400 if has_aux: 2401 program, pvals_out, consts, aux = partial_run_flat_ret 2402 else: 2403 program, pvals_out, consts = partial_run_flat_ret 2404 primal_pvals, _ = split_half(pvals_out) 2405 assert all(pval.is_known for pval in primal_pvals) 2406 primals_out = [pval.const for pval in primal_pvals] 2407 f_lin = lambda *tangents: run_program(program, [*consts, *tangents]) 2408 return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin) 2409 2410 2411def linearize(f, *primals_in, has_aux=False): 2412 primals_in_flat, in_tree = tree_flatten(primals_in) 2413 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2414 linearize_flat_ret = linearize_flat(f, *primals_in_flat, has_aux=has_aux) 2415 if has_aux: 2416 primals_out_flat, f_lin_flat, aux = linearize_flat_ret 2417 else: 2418 primals_out_flat, f_lin_flat = linearize_flat_ret 2419 2420 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2421 2422 def f_lin(*tangents_in): 2423 tangents_in_flat, in_tree2 = tree_flatten(tangents_in) 2424 if in_tree != in_tree2: 2425 raise TypeError 2426 tangents_out_flat = f_lin_flat(*tangents_in_flat) 2427 return tree_unflatten(out_tree_store(), tangents_out_flat) 2428 2429 return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin) 2430 2431 2432def tracers_to_program( 2433 tracers_in: 
List["PartialRunTraceTensor"], 2434 tracers_out: List["PartialRunTraceTensor"], 2435): 2436 def tracer_parents(t: PartialRunTraceTensor) -> List[PartialRunTraceTensor]: 2437 return t.draft.tracers_in if isinstance(t.draft, InstructionDraft) else [] 2438 2439 def draft_to_instruction(tracer_to_var: Dict[int, Var], draft: InstructionDraft) -> Instruction: 2440 inputs = [tracer_to_var[id(t)] for t in draft.tracers_in] 2441 out_binders = [Var(symval) for symval in draft.symvals_out] 2442 for t_ref, var in list_zip(draft.tracer_refs_out, out_binders): 2443 if t_ref() is not None: 2444 tracer_to_var[id(t_ref())] = var 2445 return Instruction(draft.prim, inputs, draft.params, out_binders) 2446 2447 tracer_to_var: Dict[int, Var] = {id(t): Var(SymbolicTensor.like(t.symval)) for t in tracers_in} 2448 constvar_to_val: Dict[int, Any] = {} 2449 constid_to_var: Dict[int, Var] = {} 2450 processed_instructions: Set[int] = set() 2451 instructions: List[Instruction] = [] 2452 for t in toposort(tracers_out, tracer_parents): 2453 if isinstance(t.draft, LambdaBindingDraft): 2454 assert id(t) in set(list_map(id, tracers_in)) 2455 elif isinstance(t.draft, ConstDraft): 2456 val = t.draft.val 2457 var = constid_to_var.get(id(val)) 2458 if var is None: 2459 symval = SymbolicTensor.like(get_symval(val)) 2460 var = constid_to_var[id(val)] = Var(symval) 2461 constvar_to_val[var] = val 2462 tracer_to_var[id(t)] = var 2463 elif isinstance(t.draft, InstructionDraft): 2464 if id(t.draft) not in processed_instructions: 2465 instructions += [draft_to_instruction(tracer_to_var, t.draft)] 2466 processed_instructions.add(id(t.draft)) 2467 else: 2468 raise TypeError(t.draft) 2469 2470 constvars, constvals = unzip2(constvar_to_val.items()) 2471 in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in] 2472 # out_vars = [tracer_to_var[id(t)] for t in tracers_out if id(t) in tracer_to_var] 2473 out_vars = [tracer_to_var[id(t)] for t in tracers_out] 2474 program = Program(tuple(in_binders), tuple(instructions), tuple(out_vars)) 2475 typecheck_program(program) 2476 return program, constvals 2477 2478 2479def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]): 2480 def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]): 2481 seen = set() 2482 for node in nodes: 2483 assert all(id(parent) in seen for parent in parents(node)) 2484 seen.add(id(node)) 2485 2486 def remove_duplicates(lst): 2487 seen = set() 2488 return [x for x in lst if id(x) not in seen and not seen.add(id(x))] 2489 2490 if not out_nodes: 2491 return [] 2492 out_nodes = remove_duplicates(out_nodes) 2493 2494 child_counts = {} 2495 stack = list(out_nodes) 2496 while stack: 2497 node = stack.pop() 2498 if id(node) in child_counts: 2499 child_counts[id(node)] += 1 2500 else: 2501 child_counts[id(node)] = 1 2502 stack.extend(parents(node)) 2503 for node in out_nodes: 2504 child_counts[id(node)] -= 1 2505 2506 sorted_nodes = [] 2507 childless_nodes = [node for node in out_nodes if not child_counts[id(node)]] 2508 while childless_nodes: 2509 node = childless_nodes.pop() 2510 sorted_nodes += [node] 2511 for parent in parents(node): 2512 if child_counts[id(parent)] == 1: 2513 childless_nodes += [parent] 2514 else: 2515 child_counts[id(parent)] -= 1 2516 2517 sorted_nodes = sorted_nodes[::-1] 2518 check_toposort(sorted_nodes, parents) 2519 return sorted_nodes 2520 2521 2522def vjp_flat(f, *primals_in, has_aux=False, **static_args): 2523 pvals_in = [make_known_pval(x) for x in primals_in] + 
[make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in] 2524 primal_pvals_in, tangent_pvals_in = split_half(pvals_in) 2525 del primal_pvals_in 2526 2527 def f_jvp(*primals_tangents_in): 2528 jvp_ret = jvp( 2529 f, 2530 *split_half(primals_tangents_in), 2531 has_aux=has_aux, 2532 global_data="vjp", 2533 **static_args, 2534 ) 2535 if has_aux: 2536 ((primals_out, tangents_out), aux) = jvp_ret 2537 else: 2538 (primals_out, tangents_out) = jvp_ret 2539 return ([*primals_out, *tangents_out], aux) if has_aux else [*primals_out, *tangents_out] 2540 2541 partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux, "vjp") 2542 if has_aux: 2543 program, pvals_out, consts, aux = partial_run_flat_ret 2544 else: 2545 program, pvals_out, consts = partial_run_flat_ret 2546 2547 primal_pvals, tangent_pvals = split_half(pvals_out) 2548 del tangent_pvals 2549 assert all(pval.is_known for pval in primal_pvals) 2550 primals_out_flat = [pval.const for pval in primal_pvals] 2551 transpose_inputs = consts + [UndefinedPrimal(t.symval) for t in tangent_pvals_in] 2552 2553 def f_vjp_flat(*cotangents): 2554 # return backward_pass(program, transpose_inputs, cotangents) 2555 undef_primals = tuple(isinstance(x, UndefinedPrimal) for x in transpose_inputs) 2556 transposed_program, new_consts = transpose_program(program, undef_primals) 2557 residuals, _ = partition_list(undef_primals, transpose_inputs) 2558 outs = run_program(transposed_program, (*new_consts, *residuals, *cotangents)) 2559 return outs 2560 2561 return (primals_out_flat, f_vjp_flat, aux) if has_aux else (primals_out_flat, f_vjp_flat) 2562 2563 2564def vjp(f, *primals_in, has_aux=False, **static_args): 2565 primals_in_flat, in_tree = tree_flatten(primals_in) 2566 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2567 vjp_ret = vjp_flat(f, *primals_in_flat, has_aux=has_aux, **static_args) 2568 if has_aux: 2569 primals_out_flat, f_vjp_flat, aux = vjp_ret 2570 else: 2571 primals_out_flat, f_vjp_flat = vjp_ret 2572 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2573 2574 def f_vjp(*cotangents_out): 2575 cotangents_out_flat, _ = tree_flatten(cotangents_out) 2576 cotangents_in_flat = f_vjp_flat(*cotangents_out_flat) 2577 2578 return tree_unflatten(in_tree, cotangents_in_flat) 2579 2580 return (primals_out, f_vjp, aux) if has_aux else (primals_out, f_vjp) 2581 2582 2583NullCotangent = None 2584 2585 2586def backward_pass(program: Program, args: List[Any], cotangents: List[Any]) -> List[Any]: 2587 primal_env: Dict[Var, Any] = {} 2588 ct_env: Dict[Var, Any] = {} 2589 2590 def read_primal(x: Atom) -> Any: 2591 return primal_env.get(x, UndefinedPrimal(x.symval)) if type(x) is Var else x.val 2592 2593 def write_primal(v: Var, val: Any) -> None: 2594 if type(val) is not UndefinedPrimal: 2595 primal_env[v] = val 2596 2597 def read_cotangent(v: Var) -> Any: 2598 return ct_env.pop(v, backend.zeros(v.symval.shape, v.symval.dtype)) 2599 2600 def write_cotangent(x: Atom, ct: Any): 2601 if type(x) is Var and ct is not NullCotangent: 2602 ct_env[x] = (ct_env[x] + ct) if x in ct_env else ct 2603 2604 list_map(write_primal, program.in_binders, args) 2605 list_map(write_cotangent, program.outs, cotangents) 2606 for instruction in program.instructions[::-1]: 2607 primals_in = list_map(read_primal, instruction.inputs) 2608 cotangents_in = list_map(read_cotangent, instruction.out_binders) 2609 inp, params = primals_in, instruction.params 2610 cotangents_out = instruction.op.T(cotangents_in, *inp, **params) 2611 
list_map(write_cotangent, instruction.inputs, cotangents_out) 2612 2613 ret = [read_cotangent(v) for v, x in list_zip(program.in_binders, args) if isinstance(x, UndefinedPrimal)] 2614 return ret 2615 2616 2617@lru_cache_verbose() 2618def transpose_program(program: Program, undef_primals: tuple[bool, ...]) -> tuple[Program, list[Any]]: 2619 symvals_in, symvals_out = typecheck_program(program) 2620 traceable = partial(backward_pass, program) 2621 () 2622 args = [UndefinedPrimal(a) if u else a for a, u in zip(symvals_in, undef_primals)] 2623 trans_program, consts, _ = make_program( 2624 traceable, 2625 tuple(args), 2626 tuple(symvals_out), 2627 static_args=program.static_args, 2628 name=f"{program.name}_T", 2629 ) 2630 typecheck_program(trans_program) 2631 2632 return trans_program, consts 2633 2634 2635def grad(f, argnums=(0,), argnames="", has_aux=False, return_value=False): 2636 f, rejit = (f, False) if not isinstance(f, jit) else (f.f, True) 2637 if isinstance(argnums, int): 2638 argnums = (argnums,) 2639 2640 def gfn(x, *xs, **static_args): 2641 vjp_ret = vjp(f, x, *xs, has_aux=has_aux, **static_args) 2642 if has_aux: 2643 y, f_vjp, aux = vjp_ret 2644 else: 2645 y, f_vjp = vjp_ret 2646 if np.shape(y) != (): 2647 raise TypeError("grad output must be 0-dim scalar with shape ()") 2648 gL_xs = f_vjp(backend.ones(())) 2649 gL_xs = tuple(gL_xs[i] for i in argnums) if len(argnums) > 1 else gL_xs[argnums[0]] 2650 if return_value: 2651 return ((y, aux), gL_xs) if has_aux else (y, gL_xs) 2652 else: 2653 return (gL_xs, aux) if has_aux else gL_xs 2654 2655 return jit(gfn) if rejit else gfn 2656 2657 2658def value_and_grad(f, argnums=(0,), argnames="", has_aux=False): 2659 return grad( 2660 f, 2661 argnums=argnums, 2662 argnames=argnames, 2663 has_aux=has_aux, 2664 return_value=True, 2665 ) 2666 2667 2668def jit_partial_run(trace, tracers, *, program): 2669 in_unknowns = [not t.pval.is_known for t in tracers] 2670 program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns) 2671 known_tracers, unknown_tracers = partition_list(in_unknowns, tracers) 2672 known_vals = [t.pval.const for t in known_tracers] 2673 outs1_res = backend.jit_op(*known_vals, program=program) 2674 outs1, res = split_list(outs1_res, len(program1.outs) - num_res) 2675 res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res] 2676 outs2 = [PartialRunTraceTensor(trace, PartialValue.unknown(v.symval), None) for v in program2.outs] 2677 draft = InstructionDraft( 2678 backend.jit_op, 2679 res_tracers + unknown_tracers, 2680 dict(program=program2), 2681 [v.symval for v in program2.outs], 2682 map(weakref.ref, outs2), 2683 ) 2684 for t in outs2: 2685 t.draft = draft 2686 return merge_lists(out_unknowns, outs1, outs2) 2687 2688 2689class jit: 2690 def __init__(self, f, static_argnames=(), name=None, dynamic_axes=None): 2691 if isinstance(static_argnames, str): 2692 static_argnames = tuple(static_argnames.split(" ")) 2693 assert type(static_argnames) is tuple and all(type(s) is str for s in static_argnames) 2694 self.f = f 2695 self.name = name if name is not None else self.f.__name__ 2696 self.static_argnames = static_argnames 2697 self.dynamic_axes = dynamic_axes 2698 2699 @classmethod 2700 def with_options(cls, **kwargs): 2701 return partial(cls, **kwargs) 2702 2703 @classmethod 2704 def get_jit_name(cls, args, static_args, prefix="jit", short=False): 2705 name = f"{prefix}_" 2706 if short: 2707 static_args_tup = tuple(static_args.items()) 2708 ids = repr(hash((prefix, args, 
static_args_tup)))[-4:] 2709 name = f"{prefix}_{ids}" 2710 else: 2711 for a in args: 2712 name += f"shape_{a.shape}_dtype_{a.dtype.name}_" 2713 for k, v in static_args.items(): 2714 name += f"{k}_{v}_" 2715 name = name.replace("(", "L") 2716 name = name.replace(")", "R") 2717 name = name.replace(",", "C") 2718 name = name.replace(" ", "") 2719 name = name.replace(".", "D") 2720 2721 return name 2722 2723 def get_program(self, *args, **static_args): 2724 sig = inspect.signature(self.f) 2725 if all("*" not in repr(v) for v in sig.parameters.values()): 2726 args_strs = [k for k, v in sig.parameters.items() if k != "self" and k not in self.static_argnames] 2727 static_args_strs = [k for k, v in sig.parameters.items() if k != "self" and k in self.static_argnames] 2728 2729 if args: 2730 if len(args) > len(args_strs): 2731 assert static_args_strs 2732 args, rest = args[: len(args_strs)], args[len(args_strs) :] 2733 new_static_args = {k: rest_arg for k, rest_arg in zip(static_args_strs, rest) if k not in static_args} 2734 static_args = {**new_static_args, **static_args} 2735 else: 2736 args = tuple([static_args[k] if k in static_args else arg for k, arg in zip(args_strs, args)]) 2737 2738 symvals_in = tree_map(lambda x: SymbolicTensor.like(get_symval(x)), args) 2739 static_args = tuple(static_args.items()) 2740 if self.name is None: 2741 self.name = f"jit_{str(hash((self.f, symvals_in, static_args)))[-5:]}" 2742 program, consts, out_tree = make_program(self.f, *symvals_in, static_args=static_args, name=self.name) 2743 return program, consts, out_tree 2744 2745 def __call__(self, *args, **static_args): 2746 program, consts, out_tree = self.get_program(*args, **static_args) 2747 args, in_tree = tree_flatten(args) 2748 outs = bind(backend.jit_op, *consts, *args, program=program) 2749 return tree_unflatten(out_tree, outs) 2750 2751 def lower(self, *args, **static_args): 2752 program, consts, out_tree = self.get_program(*args, **static_args) 2753 args, in_tree = tree_flatten(args) 2754 hashed_program = Hashed(program) 2755 num_consts = program.num_consts 2756 consts, args = args[:num_consts], args[num_consts:] 2757 hashed_consts = tuple(map(Hashed, consts)) 2758 jit_output = backend.jit_program(hashed_program, hashed_consts) 2759 return jit_output 2760 2761 def export(self, output_path, args, export_params=True, input_names=None, output_names=None, **kwargs): 2762 if isinstance(args, Tensor): 2763 args, static_args = (args,), dict() 2764 elif not isinstance(args[-1], dict): 2765 assert all(isinstance(a, Tensor) for a in args) 2766 static_args = dict() 2767 else: 2768 args, static_args = args[:-1], args[-1] 2769 assert isinstance(args, (tuple, list)) and isinstance(static_args, dict) 2770 jit_output = self.lower(*args, **static_args) 2771 backend.export(jit_output, output_path, export_params, input_names, output_names, **kwargs)
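The transforms defined above (grad, value_and_grad, jvp, vjp, vmap, and the jit class) compose once a backend has been selected with set_backend. A minimal end-to-end sketch, assuming a backend module named "numpy" is installed under slope.backends and registers the ops used below (ones, mul, pow, sum); substitute whatever backend your installation provides:

import slope.core as core

core.set_backend("numpy")                      # hypothetical backend name

def loss(w, x):
    return ((w * x).sum()) ** 2                # grad requires a 0-dim scalar output

w = core.backend.ones((3,))
x = core.backend.ones((3,))

g = core.grad(loss)(w, x)                      # reverse mode: vjp + program transpose, gradient w.r.t. w
y, g2 = core.value_and_grad(loss)(w, x)        # primal value and gradient together
rows = core.vmap(lambda v: (v * v).sum())(core.backend.ones((4, 3)))  # map over leading dim
fast = core.jit(loss)                          # traced to a Program, compiled and cached by the backend
print(fast(w, x))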
46class Timing(ContextDecorator): 47 def __init__(self, prefix="", on_exit=None, enabled=True): 48 self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled 49 50 def __enter__(self): 51 self.st = time.perf_counter_ns() 52 53 def __exit__(self, *exc): 54 self.et = time.perf_counter_ns() - self.st 55 if self.enabled: 56 print(f"{self.prefix}{self.et*1e-6:6.2f} ms" + (self.on_exit(self.et) if self.on_exit else ""))
Context manager (and, because it subclasses ContextDecorator, also a decorator) that records elapsed wall-clock time with time.perf_counter_ns and, if enabled, prints it in milliseconds on exit; when an on_exit callback is given, its return value for the elapsed nanoseconds is appended to the printed line.
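A small usage sketch (the timed bodies here are placeholder sleeps):

import time
from slope.core import Timing

with Timing("matmul "):
    time.sleep(0.01)                              # prints e.g. "matmul   10.12 ms"

@Timing(prefix="step ", on_exit=lambda ns: f" ({ns} ns total)")
def step():
    time.sleep(0.005)

step()                                            # prints "step     5.03 ms (... ns total)"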
59def colored(st, color: Optional[str], background=False): 60 return ( 61 f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" 62 if color is not None 63 else st 64 ) # replace the termcolor library with one line # noqa: E501
71class Profiling(ContextDecorator): 72 def __init__(self, enabled=True, sort="cumtime", frac=0.2, fn=None, ts=1): 73 self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3 / ts 74 75 def __enter__(self): 76 self.pr = cProfile.Profile() 77 if self.enabled: 78 self.pr.enable() 79 80 def __exit__(self, *exc): 81 if self.enabled: 82 self.pr.disable() 83 if self.fn: 84 self.pr.dump_stats(self.fn) 85 stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort) 86 for fcn in stats.fcn_list[0 : int(len(stats.fcn_list) * self.frac)]: # type: ignore[attr-defined] 87 (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined] 88 scallers = sorted(callers.items(), key=lambda x: -x[1][2]) 89 print( 90 f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms", 91 colored(_format_fcn(fcn), "yellow") + " " * (50 - len(_format_fcn(fcn))), 92 colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else "", 93 )
Context manager (and decorator) that wraps cProfile.Profile: when enabled it profiles the body, optionally dumps the raw stats to fn, and prints the top frac fraction of functions sorted by sort, showing call counts, per-function and cumulative times (scaled by ts), and each function's dominant caller.
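A short usage sketch (the workload is a placeholder and profile.out is an arbitrary output path):

import math
from slope.core import Profiling

def workload():
    return sum(math.sqrt(i) for i in range(100_000))

with Profiling(sort="cumtime", frac=0.2, fn="profile.out"):
    workload()
# prints the top 20% of profiled functions and writes the raw cProfile stats to profile.out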
151def lru_cache_verbose( 152 maxsize: int = 100, 153 typed: bool = False, 154 tb_start: int = -12, 155 tb_end: int = -7, 156): 157 def decorator(fn: Callable): 158 @lru_cache(maxsize=maxsize, typed=typed) 159 def wrapper(*args, **kwargs) -> Callable: 160 return fn(*args, **kwargs) 161 162 def decorated_function(*args, **kwargs) -> Any: 163 result = wrapper(*args, **kwargs) 164 cache_info = wrapper.cache_info() 165 166 dblog( 167 f"{fn.__name__}.{cache_info} {args.__hash__()}", 168 enable=backend.LOG_LRU, 169 ) 170 tb = "".join(traceback.format_list(traceback.extract_stack())[tb_start:tb_end]).replace("\n ", ":\t") + "-" * 20 + "\n" 171 dblog(f"{tb}", enable=backend.LOG_LRU) 172 173 return result 174 175 decorated_function.cache_info = wrapper.cache_info 176 decorated_function.fn = fn 177 return decorated_function 178 179 return decorator
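Usage sketch. Note that the decorator logs through backend.LOG_LRU, so a backend must already have been set, and all arguments must be hashable (the same constraint as functools.lru_cache); the backend name "numpy" below is an assumption:

import slope.core as core

core.set_backend("numpy")                  # some backend must be set: the wrapper logs via backend.LOG_LRU

@core.lru_cache_verbose(maxsize=32)
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)

fib(10)
print(fib.cache_info())                    # stats from the underlying functools.lru_cache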
def cuda_is_available():
    try:
        import subprocess
        import platform

        # platform.system is a function and must be called to get the OS name
        cmd = f"nvidia-smi{'.exe' if platform.system() == 'Windows' else ''}"
        result = subprocess.run([cmd], stdout=subprocess.PIPE)
        output = result.stdout.decode("utf-8")
        return "NVIDIA-SMI" in output
    except FileNotFoundError:
        return False
class Hashed:
    val: Any

    def __init__(self, val):
        self.val = val

    def __hash__(self) -> int:
        return hash((self.val,))

    def __eq__(self, other):
        if isinstance(other, Hashed):
            if isinstance(self.val, Tensor) and isinstance(other.val, Tensor):
                # Tensor.__eq__ is overloaded as elementwise Tensor.equal, so compare Tensors by identity
                return id(self.val) == id(other.val)
            return self.val == other.val
        return False

    def __repr__(self):
        return f"Hashed: {repr(self.val)}"
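A quick illustration of the wrapper with plain Python payloads (Tensors fall back to identity comparison, as noted in __eq__):

from slope.core import Hashed

a = Hashed((1, 2, 3))
b = Hashed((1, 2, 3))
assert a == b and hash(a) == hash(b)       # equal payloads compare and hash alike
cache = {a: "compiled"}                    # so Hashed values can key caches such as backend.jit_program's
assert cache[b] == "compiled"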
class DType(NamedTuple):
    priority: int
    itemsize: int
    name: str
    mlir: str
    numpy: type

    @property
    def format_code(self):
        return f"slope.{self.name}"

    def __repr__(self):
        return f"<DType: {self.name}>"
DType(priority, itemsize, name, mlir, numpy)
Create new instance of DType(priority, itemsize, name, mlir, numpy)
Inherited Members
- builtins.tuple
- index
- count
class dtypes:
    float32: Final[DType] = DType(4, 4, "float32", "f32", np.float32)
    uint8: Final[DType] = DType(0, 1, "uint8", "u8", np.uint8)
    int8: Final[DType] = DType(0, 1, "int8", "i8", np.int8)
    bool: Final[DType] = DType(0, 1, "bool", "i1", bool)
    int32: Final[DType] = DType(1, 4, "int32", "i32", np.int32)
    int64: Final[DType] = DType(2, 8, "int64", "i64", np.int64)
    uint64: Final[DType] = DType(2, 8, "uint64", "ui64", np.uint64)
    float16: Final[DType] = DType(0, 2, "float16", "f16", np.float16)
    half = float16
    # bfloat16: Final[DType] = DType(0, 2, "bfloat16", "bf16", np.float16)

    all_dtypes = (bool, float16, float32, int8, int32, int64, uint8, uint64)
    name_dtype_map = {k.name: k for k in all_dtypes}
    name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
    mlir_dtype_map = {k.mlir: k for k in all_dtypes}
    mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}

    @classmethod
    def is_int(cls, dtype):
        return dtype in (cls.uint8, cls.int8, cls.int32, cls.uint64, cls.int64)

    @classmethod
    def is_float(cls, dtype):
        # bfloat16 is commented out above, so it is omitted here (referencing cls.bfloat16 would raise AttributeError)
        return dtype in (cls.float16, cls.float32)
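The lookup tables make it easy to move between names, MLIR type codes, and DType records:

from slope.core import dtypes

dt = dtypes.name_dtype_map["float32"]
assert dt is dtypes.float32 and dt.mlir == "f32" and dt.itemsize == 4
assert dtypes.mlir_dtype_map["i32"] is dtypes.int32
assert dtypes.is_float(dt) and not dtypes.is_int(dt)
print(dt.format_code)                      # "slope.float32"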
class Device(NamedTuple):
    name: str
    idx: int

    @property
    def format_code(self):
        return f"'{self.name}:{self.idx}'"

    def __repr__(self):
        return f"<Device: {self.format_code}>"
Device(name, idx)
Inherited Members
- builtins.tuple
- index
- count
class devices:
    cpu: Final[Device] = Device("cpu", 0)
    metal: Final[Device] = Device("metal", 0)
    cuda0: Final[Device] = Device("cuda", 0)
    # TODO: programmatically define this class attrs to support other setup
    cuda = cuda0
    all_devices = (cpu, metal, cuda0)
    name_idx_device_map = {f"{k.name}:{k.idx}": k for k in all_devices}
    name_idx_device_map_inv = {v: k for k, v in name_idx_device_map.items()}
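Devices are looked up by their "name:idx" string, which is also how Backend.DEFAULT_DEVICE is resolved from the DEFAULT_DEVICE environment variable:

from slope.core import devices

dev = devices.name_idx_device_map["cuda:0"]
assert dev is devices.cuda0                # devices.cuda is an alias for cuda0
print(dev.format_code)                     # "'cuda:0'"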
291class Tensor: 292 def __init__(self, val: TensorBuffer): 293 assert isinstance(val, TensorBuffer) 294 self.buf = val 295 296 @property 297 def symval(self): 298 return SymbolicTensor.like(self) 299 300 @property 301 def default_dtype(self): 302 return backend.default_dtype 303 304 def is_int(self) -> bool: 305 return self.dtype in ( 306 dtypes.int8, 307 dtypes.uint8, 308 dtypes.uint64, 309 dtypes.int32, 310 dtypes.int64, 311 ) 312 313 def is_float(self) -> bool: 314 return self.dtype in (dtypes.float16, dtypes.float32) 315 316 def is_unsigned(self) -> bool: 317 return self.dtype is dtypes.uint8 318 319 def to_bool(self): 320 return self.cast(dtypes.bool) 321 322 def short(self): 323 return self.cast(dtypes.int8) 324 325 def int(self): 326 return self.cast(dtypes.int32) 327 328 def long(self): 329 return self.cast(dtypes.int64) 330 331 def half(self): 332 return self.cast(dtypes.float16) 333 334 def float(self): 335 return self.cast(dtypes.float32) 336 337 def __getattr__(self, attr): 338 if attr in vars(backend.operator_set).keys(): 339 op = getattr(backend.operator_set, attr) 340 return partial(op, self) 341 elif attr in vars(backend.procedure_set).keys(): 342 procedure = getattr(backend.procedure_set, attr) 343 assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}" 344 return partial(procedure, self) 345 else: 346 return self.__getattribute__(attr) 347 348 def __getitem__(self, idx): 349 return self.getitem(idx) 350 351 def __setitem__(self, idx, item): 352 raise NotImplementedError 353 354 def str_short(self): 355 return f"<Tensor: shape={self.shape}, dtype={self.dtype}>" 356 357 __neg__ = lambda self: self.neg() 358 __add__ = lambda self, other: self.add(other) 359 __radd__ = lambda self, other: self.add(other) 360 __sub__ = lambda self, other: self.sub(other) 361 __rsub__ = lambda self, other: self.sub.func(other, self) 362 __mul__ = lambda self, other: self.mul(other) 363 __rmul__ = lambda self, other: self.mul(other) 364 __div__ = lambda self, other: self.div(other) 365 __rdiv__ = lambda self, other: self.div.func(other, self) 366 __truediv__ = __div__ 367 __truerdiv__ = __rdiv__ 368 __pow__ = lambda self, other: self.pow(other) 369 __rpow__ = lambda self, other: self.pow.func(other, self) 370 __matmul__ = lambda self, other: self.matmul(other) 371 __rmatmul__ = lambda self, other: self.matmul.func(other, self) 372 __invert__ = lambda self: self.invert() 373 __eq__ = lambda self, other: self.equal(other) 374 __ne__ = lambda self, other: self.not_equal(other) 375 __ge__ = lambda self, other: self.greater_equal(other) 376 __le__ = lambda self, other: self.less_equal(other) 377 __gt__ = lambda self, other: self.greater(other) 378 __lt__ = lambda self, other: self.less(other) 379 380 def __hash__(self): 381 return id(self.val) 382 383 val = property(lambda self: self.buf.val) 384 385 def size(self, i): 386 return self.shape[i] 387 388 @property 389 def dtype(self): 390 return backend.dtype_of(self) 391 392 @property 393 def device(self): 394 return backend.device_of(self) 395 396 def numpy(self, memmap=False): 397 return backend.numpy_of(self, memmap) 398 399 @property 400 def shape(self): 401 return backend.shape_of(self) 402 403 @property 404 def ndim(self): 405 return len(self.shape) 406 407 def numel(self): 408 return math.prod(self.shape) 409 410 def element_size(self): 411 return self.dtype.itemsize 412 413 def nbytes(self): 414 return self.numel() * self.element_size() 415 416 def __repr__(self): 417 return f"<Tensor: 
val=\n{self.numpy()}\nshape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}>"
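Most Tensor methods are resolved dynamically through __getattr__ against the backend's operator and procedure sets, so concrete behavior depends on the backend. A hedged sketch, assuming a backend named "numpy" that provides tensor, full, and the usual elementwise and reduction ops:

import slope.core as core

core.set_backend("numpy")                          # hypothetical backend name
x = core.backend.tensor([[1.0, 2.0], [3.0, 4.0]])
y = (x + 1.0) * x                                  # dunder ops route to backend operators via __getattr__
print(y.shape, y.dtype, y.device)                  # metadata queried from the backend
print(y.half().float().numpy())                    # cast helpers defined directly on Tensor
print(x.numel(), x.nbytes())                       # 4 elements, 16 bytes for float32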
420class SymbolicTensor(Tensor): 421 def __init__(self, shape, dtype, device): 422 assert isinstance(dtype, DType) 423 self._shape = tuple(int(i) for i in shape) 424 self._dtype = dtype 425 self._device = device 426 427 @property 428 def symval(self): 429 return self 430 431 @property 432 def val(self): 433 raise RuntimeError(f"SymbolicTensor actually has no val, from {trace_stack[-1]=}, ") 434 435 @property 436 def shape(self): 437 return self._shape 438 439 @property 440 def dtype(self): 441 return self._dtype 442 443 @property 444 def device(self): 445 return self._device 446 447 def like(self, **overrides): 448 shape = overrides.get("shape", self.shape) 449 dtype = overrides.get("dtype", self.dtype) 450 device = overrides.get("device", self.device) 451 return SymbolicTensor(shape, dtype, device) 452 453 def str_short(self): 454 return f'{str(self.dtype)}[{",".join(str(d) for d in self.shape)}]' 455 456 def __hash__(self): 457 return hash((self.shape, self.dtype)) 458 459 def __eq__(self, other): 460 if type(self) != type(other): 461 return False 462 return (self.shape == other.shape) and (self.dtype == other.dtype) 463 464 def __repr__(self): 465 return f"<SymbolicTensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device}>"
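SymbolicTensor carries only shape/dtype/device metadata, so it can be constructed and manipulated without any backend:

from slope.core import SymbolicTensor, dtypes, devices

sv = SymbolicTensor((4, 8), dtypes.float32, devices.cpu)
assert sv.symval is sv                     # a symbolic tensor is its own abstract value
batched = sv.like(shape=(2, 4, 8))         # same dtype/device, overridden shape
print(sv.str_short())                      # <DType: float32>[4,8]
# sv.val would raise RuntimeError: symbolic tensors hold no data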
473class Operator: 474 def __init__(self, name, variadic_inputs=False, nary_outputs=False): 475 self.name = name 476 self.variadic_inputs = variadic_inputs 477 self.nary_outputs = nary_outputs 478 if self.variadic_inputs: 479 self.reorg_args = self.reorg_args_nary 480 481 def __hash__(self): 482 return hash(self.name) 483 484 def __eq__(self, other): 485 if not isinstance(other, Operator): 486 return False 487 return self.name == other.name 488 489 def args_fixer(self, *args, **params): 490 return args, params 491 492 def __call__(self, *args, **params): 493 args, params = self.reorg_args(args, params) 494 args, params = self.args_fixer(*args, **params) 495 ret = bind(self, *args, **params) 496 if not self.nary_outputs: 497 ret = ret[0] 498 return ret 499 500 def __repr__(self) -> str: 501 return f"<{self.name}>" 502 503 def typecheck(self, *args, **params): 504 raise NotImplementedError 505 506 def jvp(self, *args, **params): 507 raise NotImplementedError 508 509 def T(self, *args, **params): 510 raise NotImplementedError 511 512 def vmap(self, *args, **params): 513 raise NotImplementedError 514 515 def reorg_args(self, args, params): 516 sig = inspect.signature(self.typecheck) 517 args_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and k != "self"] 518 params_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.KEYWORD_ONLY and k != "self"] 519 520 if args: 521 if len(args) > len(args_strs): 522 args, rest = args[: len(args_strs)], args[len(args_strs) :] 523 if params_strs: 524 new_params = {k: rest_arg for k, rest_arg in zip(params_strs, rest) if k not in params} 525 params = {**new_params, **params} 526 else: 527 args = tuple([params[k] if k in params else arg for k, arg in zip(args_strs, args)]) 528 assert len(args) == len(args_strs) 529 return args, params 530 531 def reorg_args_nary(self, args, params): 532 return args, params 533 534 def partial_run(self, trace, tracers, **params): 535 tracers_in = [trace.instantiate_const(t) for t in tracers] 536 symvals_in = [t.symval for t in tracers_in] 537 symvals_out = self.typecheck(*symvals_in, **params) 538 tracers_out = [PartialRunTraceTensor(trace, make_unknown_pval(symval), None) for symval in symvals_out] 539 instruction = InstructionDraft( 540 self, 541 tracers_in, 542 params, 543 symvals_out, 544 list_map(weakref.ref, tracers_out), 545 ) 546 for t in tracers_out: 547 t.draft = instruction 548 return tracers_out 549 550 def partial_run_instruction(self, unks_in, instruction): 551 if any(unks_in): 552 instruction1 = None 553 instruction2 = Instruction( 554 instruction.op, 555 instruction.inputs, 556 instruction.params, 557 instruction.out_binders, 558 ) 559 unks_out = [True for i in instruction.out_binders] 560 res = [v for unk, v in zip(unks_in, instruction.inputs) if ((not unk) and type(v) is Var)] 561 else: 562 instruction1 = instruction 563 instruction2 = None 564 unks_out = [False for i in instruction.out_binders] 565 res = None 566 567 return instruction1, instruction2, unks_out, res
515 def reorg_args(self, args, params): 516 sig = inspect.signature(self.typecheck) 517 args_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and k != "self"] 518 params_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.KEYWORD_ONLY and k != "self"] 519 520 if args: 521 if len(args) > len(args_strs): 522 args, rest = args[: len(args_strs)], args[len(args_strs) :] 523 if params_strs: 524 new_params = {k: rest_arg for k, rest_arg in zip(params_strs, rest) if k not in params} 525 params = {**new_params, **params} 526 else: 527 args = tuple([params[k] if k in params else arg for k, arg in zip(args_strs, args)]) 528 assert len(args) == len(args_strs) 529 return args, params
534 def partial_run(self, trace, tracers, **params): 535 tracers_in = [trace.instantiate_const(t) for t in tracers] 536 symvals_in = [t.symval for t in tracers_in] 537 symvals_out = self.typecheck(*symvals_in, **params) 538 tracers_out = [PartialRunTraceTensor(trace, make_unknown_pval(symval), None) for symval in symvals_out] 539 instruction = InstructionDraft( 540 self, 541 tracers_in, 542 params, 543 symvals_out, 544 list_map(weakref.ref, tracers_out), 545 ) 546 for t in tracers_out: 547 t.draft = instruction 548 return tracers_out
550 def partial_run_instruction(self, unks_in, instruction): 551 if any(unks_in): 552 instruction1 = None 553 instruction2 = Instruction( 554 instruction.op, 555 instruction.inputs, 556 instruction.params, 557 instruction.out_binders, 558 ) 559 unks_out = [True for i in instruction.out_binders] 560 res = [v for unk, v in zip(unks_in, instruction.inputs) if ((not unk) and type(v) is Var)] 561 else: 562 instruction1 = instruction 563 instruction2 = None 564 unks_out = [False for i in instruction.out_binders] 565 res = None 566 567 return instruction1, instruction2, unks_out, res
570class MetaOperator(Operator): 571 def meta_impl(self, *args, **kwargs): 572 raise NotImplementedError
575class UnaryOperator(Operator): 576 def vmap(self, x, *, dim_size, vals_in, dims_in, **params): 577 (x,), (x_bdim,) = vals_in, dims_in 578 return [self(x, **params)], [x_bdim] 579 580 def typecheck(self, x, **params): 581 return [SymbolicTensor.like(x)] 582 583 def jvp(self, primals, tangents, **params): 584 (x,), (x_dot,) = primals, tangents 585 return [self(x, **params)], [self(x_dot, **params)]
588class BinaryOperator(Operator): 589 boolean_output = False 590 591 def args_fixer(self, x, w, **params): 592 if isinstance(x, UndefinedPrimal) or type(w) is UndefinedPrimal: 593 assert x.shape == w.shape 594 return (x, w), params 595 596 if type(x) in TraceTensor.PYTHON_TYPES: 597 x = backend.full(shape=(), fill_value=x, dtype=w.dtype) 598 elif type(w) in TraceTensor.PYTHON_TYPES: 599 w = backend.full(shape=(), fill_value=w, dtype=x.dtype) 600 601 shape_delta = x.ndim - w.ndim 602 if shape_delta > 0: 603 w = w.reshape((1,) * shape_delta + w.shape) 604 elif shape_delta < 0: 605 x = x.reshape((1,) * -shape_delta + x.shape) 606 607 shape_ret = tuple([max(x, w) for x, w in zip(x.shape, w.shape)]) 608 if x.shape != shape_ret: 609 x = x.expand(shape_ret) 610 if w.shape != shape_ret: 611 w = w.expand(shape_ret) 612 613 if type(x) is Tensor and isinstance(w, TraceTensor): 614 x = w._trace.pure(x) 615 elif type(w) is Tensor and isinstance(x, TraceTensor): 616 w = x._trace.pure(w) 617 # TODO: https://jax.readthedocs.io/en/latest/type_promotion.html 618 if x.dtype != w.dtype: 619 # {int, bool} -> float 620 if dtypes.is_float(x.dtype) ^ dtypes.is_float(w.dtype): 621 if dtypes.is_float(w.dtype): 622 x = x.cast(w.dtype) 623 elif dtypes.is_float(x.dtype): 624 w = w.cast(x.dtype) 625 # bool -> int 626 elif dtypes.is_int(x.dtype) ^ dtypes.is_int(w.dtype): 627 if dtypes.is_int(w.dtype): 628 x = x.cast(w.dtype) 629 elif dtypes.is_int(x.dtype): 630 w = w.cast(x.dtype) 631 else: # TODO: fine-grained type promotions 632 raise NotImplementedError("No other type promotion rules") 633 634 return (x, w), params 635 636 def vmap(self, dim_size, vals_in, dims_in, **params): 637 (x, w), (x_bdim, w_bdim) = vals_in, dims_in 638 if x_bdim != w_bdim: 639 if x_bdim is None: 640 x = VMapTrace.move_vmap_dim(x, dim_size, x_bdim, w_bdim) 641 x_bdim = w_bdim 642 else: 643 w = VMapTrace.move_vmap_dim(w, dim_size, w_bdim, x_bdim) 644 return [self(x, w, **params)], [x_bdim] 645 646 def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]: 647 if not isinstance(x, (Tensor, SymbolicTensor)) or not isinstance(y, (Tensor, SymbolicTensor)): 648 raise TypeError 649 symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype) 650 symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype) 651 if x.dtype != y.dtype: 652 raise TypeError(f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})") 653 if symx == symy: 654 return [symx] 655 shape_delta = len(symx.shape) - len(symy.shape) 656 if shape_delta > 0: 657 symy = symy.like(shape=(1,) * shape_delta + symy.shape) 658 elif shape_delta < 0: 659 symx = symx.like(shape=(1,) * -shape_delta + symx.shape) 660 if symx == symy: 661 return [symx] 662 else: 663 shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)]) 664 if symx.shape != shape_ret: 665 symx = symx.like(shape=shape_ret) 666 if symy.shape != shape_ret: 667 symy = symx.like(shape=shape_ret) 668 if symx != symy: 669 raise TypeError(f"symx ({symx}) != symy ({symy})") 670 return [symx] 671 672 def jvp(self, primals, tangents, **params): 673 (x, w), (x_dot, w_dot) = primals, tangents 674 return [self(x, w, **params)], [self(x_dot, w_dot, **params)] 675 676 def T(self, cotangents, x, w): 677 (gL_y,) = cotangents 678 if self.boolean_output: 679 gL_y = gL_y.cast(x.dtype) 680 if isinstance(x, UndefinedPrimal): 681 return [gL_y, NullCotangent] 682 elif isinstance(w, UndefinedPrimal): 683 return [NullCotangent, gL_y] 684 else: 685 raise ValueError
591 def args_fixer(self, x, w, **params): 592 if isinstance(x, UndefinedPrimal) or type(w) is UndefinedPrimal: 593 assert x.shape == w.shape 594 return (x, w), params 595 596 if type(x) in TraceTensor.PYTHON_TYPES: 597 x = backend.full(shape=(), fill_value=x, dtype=w.dtype) 598 elif type(w) in TraceTensor.PYTHON_TYPES: 599 w = backend.full(shape=(), fill_value=w, dtype=x.dtype) 600 601 shape_delta = x.ndim - w.ndim 602 if shape_delta > 0: 603 w = w.reshape((1,) * shape_delta + w.shape) 604 elif shape_delta < 0: 605 x = x.reshape((1,) * -shape_delta + x.shape) 606 607 shape_ret = tuple([max(x, w) for x, w in zip(x.shape, w.shape)]) 608 if x.shape != shape_ret: 609 x = x.expand(shape_ret) 610 if w.shape != shape_ret: 611 w = w.expand(shape_ret) 612 613 if type(x) is Tensor and isinstance(w, TraceTensor): 614 x = w._trace.pure(x) 615 elif type(w) is Tensor and isinstance(x, TraceTensor): 616 w = x._trace.pure(w) 617 # TODO: https://jax.readthedocs.io/en/latest/type_promotion.html 618 if x.dtype != w.dtype: 619 # {int, bool} -> float 620 if dtypes.is_float(x.dtype) ^ dtypes.is_float(w.dtype): 621 if dtypes.is_float(w.dtype): 622 x = x.cast(w.dtype) 623 elif dtypes.is_float(x.dtype): 624 w = w.cast(x.dtype) 625 # bool -> int 626 elif dtypes.is_int(x.dtype) ^ dtypes.is_int(w.dtype): 627 if dtypes.is_int(w.dtype): 628 x = x.cast(w.dtype) 629 elif dtypes.is_int(x.dtype): 630 w = w.cast(x.dtype) 631 else: # TODO: fine-grained type promotions 632 raise NotImplementedError("No other type promotion rules") 633 634 return (x, w), params
636 def vmap(self, dim_size, vals_in, dims_in, **params): 637 (x, w), (x_bdim, w_bdim) = vals_in, dims_in 638 if x_bdim != w_bdim: 639 if x_bdim is None: 640 x = VMapTrace.move_vmap_dim(x, dim_size, x_bdim, w_bdim) 641 x_bdim = w_bdim 642 else: 643 w = VMapTrace.move_vmap_dim(w, dim_size, w_bdim, x_bdim) 644 return [self(x, w, **params)], [x_bdim]
646 def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]: 647 if not isinstance(x, (Tensor, SymbolicTensor)) or not isinstance(y, (Tensor, SymbolicTensor)): 648 raise TypeError 649 symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype) 650 symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype) 651 if x.dtype != y.dtype: 652 raise TypeError(f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})") 653 if symx == symy: 654 return [symx] 655 shape_delta = len(symx.shape) - len(symy.shape) 656 if shape_delta > 0: 657 symy = symy.like(shape=(1,) * shape_delta + symy.shape) 658 elif shape_delta < 0: 659 symx = symx.like(shape=(1,) * -shape_delta + symx.shape) 660 if symx == symy: 661 return [symx] 662 else: 663 shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)]) 664 if symx.shape != shape_ret: 665 symx = symx.like(shape=shape_ret) 666 if symy.shape != shape_ret: 667 symy = symx.like(shape=shape_ret) 668 if symx != symy: 669 raise TypeError(f"symx ({symx}) != symy ({symy})") 670 return [symx]
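The shape logic in args_fixer and typecheck is NumPy-style broadcasting: ranks are aligned by prepending 1s and the result shape is the elementwise maximum. A standalone illustration (broadcast_shape is a hypothetical helper, not part of slope):

def broadcast_shape(xs: tuple, ws: tuple) -> tuple:
    # align ranks by prepending singleton dims, then take the elementwise max
    delta = len(xs) - len(ws)
    if delta > 0:
        ws = (1,) * delta + ws
    elif delta < 0:
        xs = (1,) * -delta + xs
    return tuple(max(a, b) for a, b in zip(xs, ws))

assert broadcast_shape((4, 1, 3), (5, 3)) == (4, 5, 3)
assert broadcast_shape((), (2, 2)) == (2, 2)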
688class ReduceOperator(Operator): 689 def args_fixer(self, x, *, dim=None, keepdim=False): 690 if dim is None: 691 dim = tuple(range(x.ndim)) 692 elif isinstance(dim, int): 693 dim = (dim,) 694 dim = tuple(a if a >= 0 else a + len(x.shape) for a in dim) 695 return (x,), dict(dim=dim, keepdim=keepdim) 696 697 def vmap(self, dim_size, vals_in, dims_in, *, dim, keepdim): 698 (x,), (x_bdim,) = vals_in, dims_in 699 dim = tuple(a + (x_bdim <= a) for a in dim) 700 out_bdim = x_bdim - sum(a < x_bdim for a in dim) 701 return [self(x, dim=dim, keepdim=keepdim)], [out_bdim] 702 703 def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]: 704 dim = list(set([a + len(x.shape) if a < 0 else a for a in dim])) 705 if keepdim: 706 new_shape = [d if i not in dim else 1 for i, d in enumerate(x.shape)] 707 else: 708 new_shape = [d for i, d in enumerate(x.shape) if i not in dim] 709 return [SymbolicTensor.like(x, shape=tuple(new_shape))]
703 def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]: 704 dim = list(set([a + len(x.shape) if a < 0 else a for a in dim])) 705 if keepdim: 706 new_shape = [d if i not in dim else 1 for i, d in enumerate(x.shape)] 707 else: 708 new_shape = [d for i, d in enumerate(x.shape) if i not in dim] 709 return [SymbolicTensor.like(x, shape=tuple(new_shape))]
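The output shape of a reduction follows directly from the normalized dim tuple. A small standalone illustration (reduced_shape is a hypothetical helper mirroring typecheck above):

def reduced_shape(shape: tuple, dim: tuple, keepdim: bool = False) -> tuple:
    dim = {d % len(shape) for d in dim}        # normalize negative axes, as args_fixer does
    if keepdim:
        return tuple(1 if i in dim else s for i, s in enumerate(shape))
    return tuple(s for i, s in enumerate(shape) if i not in dim)

assert reduced_shape((2, 3, 4), (-1,)) == (2, 3)
assert reduced_shape((2, 3, 4), (0, 2), keepdim=True) == (1, 3, 1)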
712class InitOperator(Operator): 713 def vmap(self, dim_size, vals_in, dims_in, **params): 714 (x_bdim,) = dims_in 715 y = self(**params) 716 y = y.unsqueeze(x_bdim) 717 return [y], [x_bdim] 718 719 def jvp(self, primals, tangents, **params): 720 y = self(**params) 721 y_dot = NullCotangent(y.symval) 722 return [y], [y_dot] 723 724 def T(self, cotangents, **params): 725 return [NullCotangent(cotangents[0])]
736class OperatorSet: 737 def __init__(self): 738 self.register("jit_op")(JitOp) 739 740 def register(self, name, variadic_inputs=False, nary_outputs=False, aliases=()): 741 def wrap(op_cls): 742 assert name not in vars(self) 743 op = op_cls(name, variadic_inputs, nary_outputs) 744 setattr(self, name, op) 745 for a in aliases: 746 setattr(self, a, op) 747 return op_cls 748 749 return wrap
740 def register(self, name, variadic_inputs=False, nary_outputs=False, aliases=()): 741 def wrap(op_cls): 742 assert name not in vars(self) 743 op = op_cls(name, variadic_inputs, nary_outputs) 744 setattr(self, name, op) 745 for a in aliases: 746 setattr(self, a, op) 747 return op_cls 748 749 return wrap
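register is how a backend populates its operator namespace. A hedged sketch of defining and registering a new unary operator; the op name and the jvp rule here are illustrative, not part of slope:

from slope.core import OperatorSet, UnaryOperator

operator_set = OperatorSet()

@operator_set.register("relu_like")              # hypothetical operator name
class ReluLike(UnaryOperator):
    def jvp(self, primals, tangents, **params):
        (x,), (x_dot,) = primals, tangents
        y = self(x, **params)
        return [y], [x_dot * (x > 0)]            # assumes mul/greater are available on the tensors

relu_like = operator_set.relu_like               # the registered Operator instance, callable as relu_like(x)

A backend would additionally attach a concrete implementation with Backend.set_impl before the operator can actually be executed.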
class CodegenOutput(NamedTuple):
    code_lines: List[str]
    fn_defs: Dict[str, List[str]]
    in_binders: List["ProgramEnvVar"]
    outs: List["ProgramEnvVar"]
771class Backend: 772 LOG_LRU = int(os.environ.get("LOG_LRU", 0)) 773 LOG_JIT = int(os.environ.get("LOG_JIT", 0)) 774 LOG_TREE = int(os.environ.get("LOG_TREE", 0)) 775 LOG_BACKEND = int(os.environ.get("LOG_BACKEND", 0)) 776 LOG_PROGRAM = int(os.environ.get("LOG_PROGRAM", 0)) 777 LOG_INIT = int(os.environ.get("LOG_INIT", 1)) 778 device_var = os.environ.get("DEFAULT_DEVICE", "cpu:0") 779 if device_var[-2] != ":": 780 device_var += ":0" 781 DEFAULT_DEVICE = devices.name_idx_device_map[device_var] 782 DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")] 783 dtype_for_indices: DType = None # need to override 784 785 def __init__( 786 self, 787 operator_set: OperatorSet, 788 procedure_set: ProcedureSet, 789 ): 790 self.operator_set = operator_set 791 self.procedure_set = procedure_set 792 self.node_types = dict() 793 self.impls = dict() 794 self.register_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs), "tuple") 795 self.register_node(list, lambda l: (None, l), lambda _, xs: list(xs), "list") 796 self.register_node( 797 dict, 798 lambda d: list_map(tuple, unzip2(sorted(d.items()))), 799 lambda keys, vals: dict(list_zip(keys, vals)), 800 "dict", 801 ) 802 self.register_node( 803 UndefinedPrimal, 804 lambda u: (u.symval, ()), 805 lambda symval, _: UndefinedPrimal(symval), 806 "UndefinedPrimal", 807 ) 808 809 def set_impl(self, op: Union[types.LambdaType, types.FunctionType]): 810 def set_impl_(fn): 811 self.impls[op] = types.MethodType(fn, self) 812 813 return set_impl_ 814 815 def register_node(self, ty: Type, to_iter: Callable, from_iter: Callable, name=None) -> None: 816 if name is None: 817 name = str(ty) 818 self.node_types[ty] = NodeType(name, to_iter, from_iter) 819 820 def __getattr__(self, attr): 821 try: 822 dblog( 823 f"Looking {self}.{attr} in operator_set", 824 enable=backend.LOG_BACKEND, 825 ) 826 return getattr(self.operator_set, attr) 827 except: 828 pass 829 try: 830 dblog( 831 f"Looking {self}.{attr} in procedure_set", 832 enable=backend.LOG_BACKEND, 833 ) 834 return getattr(self.procedure_set, attr) 835 except: 836 pass 837 dblog( 838 f"Fallback to default {self} getattribute", 839 enable=backend.LOG_BACKEND, 840 ) 841 super().__getattribute__(attr) 842 843 def tensor( 844 self, 845 val: Union[list, tuple, np.ndarray, "TensorBuffer"] = None, 846 dtype: Optional[Any] = None, 847 device=None, 848 ): 849 if isinstance(val, TensorBuffer): 850 return Tensor(val) 851 elif isinstance(val, Tensor): 852 return val 853 if type(val) is bytes: 854 val = np.frombuffer(val, dtype=dtype) 855 return self.from_numpy(val, dtype, device) 856 857 def symbolic_tensor( 858 self, 859 shape: Union[list, tuple, np.ndarray, "TensorBuffer"] = None, 860 dtype: Optional[Any] = None, 861 device=None, 862 ): 863 dtype = dtype or self.DEFAULT_DTYPE 864 device = device or self.DEFAULT_DEVICE 865 return SymbolicTensor(shape, dtype, device) 866 867 def seed(self, seed): 868 raise NotImplementedError 869 870 @property 871 def default_dtype_value(self): 872 return self.dtype_map[backend.DEFAULT_DTYPE] 873 874 def set_method(self, method): 875 setattr(self, method.__name__, types.MethodType(method, self)) 876 877 def from_numpy(self, val, device): 878 raise NotImplementedError 879 880 def numpy_of(self, tensor): 881 raise NotImplementedError 882 883 def device_of(self, tensor): 884 raise NotImplementedError 885 886 def shape_of(self, tensor): 887 raise NotImplementedError 888 889 def dtype_of(self, tensor): 890 raise NotImplementedError 891 892 @lru_cache_verbose() 893 def 
jit_program( 894 self, 895 hashed_program: Hashed, 896 hashed_consts: Tuple[Hashed, ...], 897 ): 898 program: Program = hashed_program.val 899 typecheck_program(program) 900 consts = [x.val for x in hashed_consts] 901 in_symvals = [v.symval for v in program.in_binders[len(consts) :]] 902 codegen_output: CodegenOutput = self.codegen(program, consts + in_symvals, fn_name="main") 903 fn, code = self.compile(codegen_output) 904 jit_output = JitOutput(program, codegen_output, fn, code, consts) 905 return jit_output 906 907 def codegen(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str): 908 "Returns compiler IR from the Program" 909 raise NotImplementedError 910 911 def compile(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str): 912 "Compiles compiler IR to a Python callable function" 913 raise NotImplementedError 914 915 def export(self, jit_output, *args, **params): 916 raise NotImplementedError 917 918 def load(self, path, single_key="_tensor"): 919 with open(path, mode="rb") as f: 920 with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as m: 921 json_len = np.int64(m[0]) 922 start = 8 + json_len 923 metadata = json.loads(m[8:start]) 924 ret = {} 925 for k, v in metadata.items(): 926 if k != "__metadata__": 927 dtype = Tensor.mlir_dtype_map[(v["dtype"])] 928 data_start = start + v["data_offsets"][0] 929 data_end = start + v["data_offsets"][1] 930 t_np = np.frombuffer(m[data_start:data_end], dtype=dtype.numpy()) 931 t = backend.tensor(t_np, dtype=dtype) 932 t = t.reshape(tuple(v["shape"])) 933 ret[k] = t 934 if len(ret) == 1 and single_key in ret.keys(): 935 return ret[single_key] 936 return ret 937 938 def save(self, tensors: Dict[str, Tensor], path: str, single_key="_tensor"): 939 if isinstance(tensors, Tensor): 940 tensors = {single_key: tensors} 941 else: 942 assert all((isinstance(k, str) and isinstance(v, Tensor)) for k, v in tensors.items()) 943 944 metadata, offset = {}, 0 945 for k, v in tensors.items(): 946 metadata[k] = { 947 "dtype": v.dtype.mlir, 948 "shape": list(v.shape), 949 "data_offsets": [offset, offset + v.nbytes()], 950 } 951 offset += v.nbytes() 952 j = json.dumps(metadata, separators=(",", ":")) 953 Path(path).unlink(missing_ok=True) 954 jbytes = j.encode("utf-8") 955 start = 8 + len(jbytes) 956 with open(path, mode="wb") as f: # make empty file, fill with enough space 957 f.write(b"\x00" * (start + offset)) 958 with open(path, mode="r+b") as f: 959 with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_WRITE) as m: 960 m[0:8] = np.int64(len(j)).tobytes() 961 m[8:start] = jbytes 962 for t, tm in zip(tensors.values(), metadata.values()): 963 data_start, data_end = tm["data_offsets"] 964 m[start + data_start : start + data_end] = t.numpy().tobytes()
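save and load above use a safetensors-style layout: an 8-byte little-endian length, a JSON metadata header, then the raw tensor bytes. A hedged sketch of inspecting such a file with plain numpy/json (the path is hypothetical):

import json
import numpy as np

with open("/tmp/slope_program/c0.safetensors", "rb") as f:  # hypothetical file written by Program.save
    raw = f.read()
header_len = int(np.frombuffer(raw[:8], dtype=np.int64)[0])  # first 8 bytes: JSON header length
meta = json.loads(raw[8:8 + header_len])
for name, info in meta.items():
    if name != "__metadata__":
        start, end = info["data_offsets"]  # offsets are relative to the end of the header
        print(name, info["dtype"], info["shape"], end - start, "bytes")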
978class Lit: 979 def __init__(self, val): 980 self.symval = SymbolicTensor.like(get_symval(val)) 981 self.val = val
987class Instruction(NamedTuple): 988 op: Operator 989 inputs: List[Atom] 990 params: Dict[str, Any] 991 out_binders: List[Atom]
994class ProgramEnvVar(NamedTuple): 995 name: str 996 symval: SymbolicTensor 997 is_const: bool = False 998 999 @property 1000 def shape(self): 1001 return self.symval.shape 1002 1003 @property 1004 def dtype(self): 1005 return self.symval.dtype 1006 1007 @property 1008 def device(self): 1009 return self.symval.device 1010 1011 @property 1012 def ndim(self): 1013 return self.symval.ndim 1014 1015 def numpy(self): 1016 return self.symval.numpy() 1017 1018 def __repr__(self): 1019 return f"<ProgramEnvVar: name={self.name}, symval={self.symval}>" 1020 1021 str_short = __repr__
1024class Program: 1025 def __init__( 1026 self, 1027 in_binders: Any, 1028 instructions: Tuple[Instruction], 1029 outs: Any, 1030 num_consts: int = 0, 1031 static_args: Any = (), 1032 name: str = "my_program", 1033 indent_amount=4, 1034 ): 1035 self.in_binders: Any = in_binders 1036 self.outs: Any = outs 1037 self.instructions = self.prune_instructions(instructions, outs) 1038 self.num_consts: int = num_consts 1039 self.static_args = static_args 1040 self.name: str = name 1041 self.indent_amount: int = indent_amount 1042 1043 self.env: Dict[ProgramEnvVar, Any] = dict() 1044 for inb in self.in_binders: 1045 prefix = "x" if type(inb.symval) is SymbolicTensor else "c" 1046 idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()]) 1047 self.env[inb] = ProgramEnvVar(f"{prefix}{idx}", inb.symval, True if prefix == "c" else False) 1048 for instruction in self.instructions: 1049 if len(instruction.out_binders) == 0: 1050 continue 1051 for outb in instruction.out_binders: 1052 prefix = "y" if outb in self.outs else "z" 1053 idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()]) 1054 self.env[outb] = ProgramEnvVar(f"{prefix}{idx}", outb.symval) 1055 self.curr_repr = repr(self) 1056 1057 def pprint_shape(self, symval, scalar_as_empty_array=False): 1058 xdtype = symval.dtype.mlir 1059 if len(symval.shape) > 0: 1060 xshape = f"{', '.join((repr(i) for i in symval.shape))}" 1061 return f"[{xshape}, {xdtype}]" 1062 else: 1063 return f"[{xdtype}]" 1064 1065 def pprint_sig(self, in_symvals, out_symvals, unpack_unary_output=False): 1066 in_code = ", ".join(self.pprint_shape(t) for t in in_symvals) 1067 in_code = f"({in_code})" if len(in_symvals) > 1 else in_code 1068 out_code = ", ".join(self.pprint_shape(t) for t in out_symvals) 1069 out_code = f"({out_code})" if len(out_symvals) > 1 or unpack_unary_output else out_code 1070 typing_code = f"{in_code} -> {out_code}" 1071 return typing_code 1072 1073 def __repr__(self): 1074 fn_defs = self.instructions_as_code(self, dict()) 1075 return "\n".join(line for code_lines in fn_defs.values() for line in code_lines) 1076 1077 def save(self, *args, dir_path="/tmp/slope_program", dry_run=False): 1078 os.makedirs(dir_path, exist_ok=True) 1079 head_code_lines = [f"import slope # backend={backend.__class__.__name__}"] 1080 fn_defs = self.instructions_as_code(self, dict()) 1081 in_binders_vars = [self.env[i] for i in self.in_binders] 1082 for i in range(len(self.in_binders)): 1083 ibv = in_binders_vars[i] 1084 if ibv.is_const: 1085 const_filename = f"{ibv.name}.safetensors" 1086 const_path = os.path.join(dir_path, f"{const_filename}") 1087 if not dry_run: 1088 backend.save(args[i], const_path) 1089 dblog( 1090 f"Saved {ibv.name} at {const_path}", 1091 enable=backend.LOG_BACKEND, 1092 ) 1093 head_code_lines += [f"""{ibv.name} = slope.load("./{const_filename}")"""] 1094 head_code_lines += [""] 1095 code = "\n".join(head_code_lines + [line for code_lines in fn_defs.values() for line in code_lines]) 1096 dblog( 1097 f"Contents of {self.name}:\n```\n{code}\n```", 1098 enable=backend.LOG_BACKEND, 1099 ) 1100 program_path = os.path.join(dir_path, "main.py") 1101 if not dry_run: 1102 with open(program_path, "w") as f: 1103 f.write(code) 1104 dblog( 1105 f"Saved program {self.name} at {program_path}", 1106 enable=backend.LOG_BACKEND, 1107 ) 1108 ls_contents = "\n\t".join(os.listdir(dir_path)) 1109 dblog( 1110 f"Contents of {dir_path}:\n\t{ls_contents}", 1111 enable=backend.LOG_BACKEND, 1112 ) 1113 1114 def __hash__(self): 1115 return 
hash(self.curr_repr) 1116 1117 def __eq__(self, other): 1118 return self is other 1119 1120 @classmethod 1121 def instructions_as_code(cls, program, fn_defs): 1122 def indent(code, indent_amount): 1123 spaces = " " * (len(code) - len(code.lstrip())) 1124 spaces += " " * indent_amount 1125 return "\n".join([spaces + line for line in code.strip().split("\n")]) 1126 1127 in_binders_vars = [program.env[i] for i in program.in_binders] 1128 body_code_lines = [] 1129 for instruction in program.instructions: 1130 if len(instruction.out_binders) == 0: 1131 continue 1132 params = instruction.params.copy() 1133 for param_name, param in params.items(): 1134 if isinstance(param, Program): 1135 sub_program = param 1136 fn_defs = cls.instructions_as_code(sub_program, fn_defs) 1137 program_in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs) 1138 params[param_name] = f"slope.make_program({sub_program.name}, {program_in_vals})[0]" 1139 if isinstance(param, DType): 1140 params[param_name] = f"slope.{param.name}" 1141 param_vals = ", ".join(f"{param_name}={param}" for param_name, param in params.items()) 1142 in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs) 1143 out_vals = ", ".join(f"{program.env[z].name}" for z in instruction.out_binders) 1144 sig = program.pprint_sig( 1145 [program.env[x].symval for x in instruction.inputs], 1146 [program.env[y].symval for y in instruction.out_binders], 1147 ) 1148 line = f"""{out_vals} = slope.{instruction.op.name}({in_vals}{", " if (param_vals and in_vals) else ""}{param_vals}) # {sig}""" 1149 body_code_lines += [indent(line, program.indent_amount)] 1150 1151 fn_args_str = ", ".join([f"{i.name}" for i in in_binders_vars]) 1152 # fn_static_args_str = ", ".join([f"{a}={a_val}" for a, a_val in program.static_args]) 1153 out_vars = [program.env[o] for o in program.outs] 1154 fn_sig = program.pprint_sig( 1155 [i.symval for i in in_binders_vars], 1156 [o.symval for o in out_vars], 1157 ) 1158 head_code_line = [f"def {program.name}({fn_args_str}): # {fn_sig}"] 1159 out_str = ", ".join([f"{o.name}" for o in out_vars]) 1160 tail_code_line = [indent(f"return {out_str}", program.indent_amount)] 1161 code_lines = head_code_line + body_code_lines + tail_code_line + ["\n"] 1162 1163 fn_defs[program.name] = code_lines 1164 return fn_defs 1165 1166 @staticmethod 1167 def prune_instructions(instructions, outs): 1168 graph = dict() 1169 for instruction in instructions: 1170 parent_nodes, child_nodes = instruction.out_binders, instruction.inputs 1171 for parent in parent_nodes: 1172 if parent not in graph: 1173 graph[parent] = set() 1174 for child in child_nodes: 1175 graph[parent].add(child) 1176 visited_from_terminal = set() 1177 1178 def dfs(node, visited): 1179 visited.add(node) 1180 if node in graph: 1181 for neighbor in graph[node]: 1182 if neighbor not in visited: 1183 dfs(neighbor, visited) 1184 1185 for terminal_node in outs: 1186 dfs(terminal_node, visited_from_terminal) 1187 unreachable_nodes = set(graph.keys()) - visited_from_terminal 1188 1189 instructions_to_prune = [] 1190 for instruction in instructions: 1191 parent_nodes, child_nodes = instruction.out_binders, instruction.inputs 1192 if any(node in unreachable_nodes for node in parent_nodes) or any(node in unreachable_nodes for node in child_nodes): 1193 instructions_to_prune += [instruction] 1194 new_instructions = [inst for inst in instructions if inst not in instructions_to_prune] 1195 if backend.LOG_PROGRAM: 1196 LI = len(instructions) 1197 LNI = len(new_instructions) 1198 
DIFF = LI - LNI 1199 UN = len(unreachable_nodes) 1200 dblog(f"Before: {LI}\tAfter: {LNI}\tDiff vs Unreachables: {DIFF} == {UN} = {DIFF==UN}") 1201 return new_instructions
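For orientation, a hypothetical example of the Python text that repr(program) / instructions_as_code emit; the operator names, shapes, and the f32 dtype tag are illustrative only:

def my_program(x0): # [2, 3, f32] -> [2, 3, f32]
    z0 = slope.exp(x0) # [2, 3, f32] -> [2, 3, f32]
    y0 = slope.mul(x0, z0) # ([2, 3, f32], [2, 3, f32]) -> [2, 3, f32]
    return y0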
1204class ProgramType(NamedTuple): 1205 in_types: Tuple[SymbolicTensor] 1206 out_types: Tuple[SymbolicTensor] 1207 1208 def __repr__(self): 1209 in_types = ", ".join(symval.str_short() for symval in self.in_types) 1210 out_types = ", ".join(symval.str_short() for symval in self.out_types) 1211 return f"({in_types}) -> ({out_types})"
1226class Store: 1227 val = empty 1228 1229 def set_value(self, val): 1230 assert self.val is empty 1231 self.val = val 1232 1233 def __call__(self): 1234 return self.val
NodeType(name, flatten, unflatten)
1243class TreeDef(NamedTuple): 1244 node_type: NodeType 1245 node_metadata: Hashable 1246 child_treedefs: Tuple["TreeDef", ...] 1247 1248 def __repr__(self): 1249 ret = self.tree_repr(self) 1250 return ret 1251 1252 def tree_repr(self, tree, indent=" ", prefix="", last=True): 1253 ret = "" 1254 1255 def _tree_repr(tree, indent, prefix, last): 1256 nonlocal ret 1257 if isinstance(tree, TreeDef): 1258 ret += f'{prefix} {("└─" if last else "├─")} {tree.node_type.name}\n' 1259 for i, item in enumerate(tree.child_treedefs): 1260 new_prefix = prefix + (indent if not last else " ") 1261 new_last = i == len(tree.child_treedefs) - 1 1262 _tree_repr(item, indent, new_prefix, new_last) 1263 else: 1264 ret += f'{prefix} {("└─" if last else "├─")} {tree}\n' 1265 1266 _tree_repr(tree, indent=" ", prefix="", last=True) 1267 return ret 1268 1269 @property 1270 def num_leaves(self): 1271 def get_num_leaves(x): 1272 if isinstance(x, Leaf): 1273 return 1 1274 else: 1275 return sum(get_num_leaves(sub_x) for sub_x in x.child_treedefs) 1276 1277 return sum(get_num_leaves(x) for x in self.child_treedefs)
1280class Leaf: 1281 def __init__(self, val): 1282 if hasattr(val, "shape"): 1283 val = SymbolicTensor.like(val) 1284 self.val = val 1285 1286 def __repr__(self): 1287 ret = self.val.str_short() if isinstance(self.val, SymbolicTensor) else repr(self.val) 1288 return f"<Leaf: {ret}>" 1289 1290 def __hash__(self): 1291 return hash(self.val) 1292 1293 def __eq__(self, other): 1294 return True # make TreeDef __eq__ don't care Leaf 1295 # if isinstance(other, Leaf): # TODO: test above assumption 1296 # return self.val == other.val
1304class JitOutput: 1305 def __init__(self, program: Program, codegen_output: CodegenOutput, fn, code: str, consts: List[Any]): 1306 super().__init__() 1307 self.program = program 1308 self.code = code 1309 self.codegen_output = codegen_output 1310 self.fn: Callable = fn 1311 self.consts = consts 1312 1313 def __call__(self, *args, **params): 1314 args, in_tree = tree_flatten(args) 1315 args = tree_map(lambda a: a.val if isinstance(a, Tensor) else a, args) 1316 try: 1317 outs = self.fn(*args, **params) 1318 if not isinstance(outs, tuple): # TODO: IREE FunctionInvoker destructure 1-tuple, need to undo 1319 outs = (outs,) 1320 except Exception as e: 1321 dblog(self.code, enable=backend.LOG_JIT) 1322 raise 1323 return [backend.tensor(TensorBuffer(o)) for o in outs]
1326class JitOp(MetaOperator): 1327 def meta_impl(self, *args, program: Program, **_): 1328 hashed_program = Hashed(program) 1329 num_consts = program.num_consts 1330 consts, args = args[:num_consts], args[num_consts:] 1331 hashed_consts = tuple(map(Hashed, consts)) 1332 jit_output = backend.jit_program(hashed_program, hashed_consts) 1333 ret = jit_output(*consts, *args) 1334 return ret 1335 1336 def reorg_args(self, args, params): 1337 return args, params 1338 1339 def typecheck(self, *in_types, program: Program): 1340 program_type = typecheck_program(program) 1341 if not all(t1 == t2 for t1, t2 in zip(program_type.in_types, in_types)): 1342 ret = "Type mismatch program.in_types vs in_types:\n" 1343 for i, j in zip(program_type.in_types, in_types): 1344 ret += f"{i}, {j}, {i == j}" 1345 raise TypeError(ret) 1346 return program_type.out_types 1347 1348 def vmap(self, dim_size, vals_in, dims_in, program: Program): 1349 program, consts = vmap_program(program, dim_size, tuple(dims_in)) 1350 outs = self(*consts, *vals_in, program=program) 1351 if not isinstance(outs, tuple): 1352 outs = (outs,) 1353 return outs, [0] * len(outs) 1354 1355 def jvp(self, primals, tangents, *, program): 1356 new_program, new_consts = jvp_program(program) 1357 outs = bind( 1358 self, 1359 *new_consts, 1360 *primals, 1361 *tangents, 1362 program=new_program, 1363 ) 1364 n = len(outs) // 2 1365 primals_out, tangents_out = outs[:n], outs[n:] 1366 return primals_out, tangents_out 1367 1368 def T(self, cotangents, *invals, program): 1369 undef_primals = [isinstance(x, UndefinedPrimal) for x in invals] 1370 transposed_program, new_consts = transpose_program(program, tuple(undef_primals)) 1371 1372 residuals, _ = partition_list(undef_primals, invals) 1373 outs = bind( 1374 self, 1375 *new_consts, 1376 *residuals, 1377 *cotangents, 1378 program=transposed_program, 1379 ) 1380 outs = iter(outs) 1381 1382 return [next(outs) if undef else None for undef in undef_primals] 1383 1384 def partial_run(self, trace, tracers, *, program): 1385 in_unknowns = [not t.pval.is_known for t in tracers] 1386 program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns) 1387 known_tracers, unknown_tracers = partition_list(in_unknowns, tracers) 1388 known_vals = [t.pval.const for t in known_tracers] 1389 outs1_res = bind(backend.jit_op, *known_vals, program=program1) 1390 outs1, res = split_list(outs1_res, len(program1.outs) - num_res) 1391 res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res] 1392 outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs] 1393 instruction = InstructionDraft( 1394 self, 1395 res_tracers + unknown_tracers, 1396 dict(program=program2), 1397 [v.symval for v in program2.outs], 1398 list_map(weakref.ref, outs2), 1399 ) 1400 for t in outs2: 1401 t.draft = instruction 1402 1403 return merge_lists(out_unknowns, outs1, outs2) 1404 1405 def partial_run_instruction(self, unks_in, instruction) -> Tuple[Instruction, Instruction, List[bool], List[Var]]: 1406 program = instruction.params["program"] 1407 program1, program2, out_unknowns, num_res = partial_run_program(program, unks_in) 1408 ins1, ins2 = partition_list(unks_in, instruction.inputs) 1409 out_binders1, out_binders2 = partition_list(out_unknowns, instruction.out_binders) 1410 res = [Var(v.symval) for v in program2.in_binders[:num_res]] 1411 instruction1 = Instruction(self, ins1, dict(program=program1), out_binders1 + res) 1412 instruction2 = Instruction(self, res + ins2, 
dict(program=program2), out_binders2) 1413 return instruction1, instruction2, out_unknowns, res
1416class MainTrace(NamedTuple): 1417 level: int 1418 trace_type: Type["Trace"] 1419 global_data: Optional[Any]
1422class Trace: 1423 main: MainTrace 1424 1425 def __init__(self, main: MainTrace) -> None: 1426 self.main = main 1427 1428 def pure(self, val): 1429 raise NotImplementedError 1430 1431 def run_op(self, op, tracers, params): 1432 raise NotImplementedError
1435class RunTrace(Trace): 1436 pure = lambda self, x: x 1437 1438 def run_op(self, op: Operator, args, params): 1439 if isinstance(op, MetaOperator): 1440 args, params = op.reorg_args(args, params) 1441 args, params = op.args_fixer(*args, **params) 1442 ret = op.meta_impl(*args, **params) 1443 else: 1444 fn = self.get_fn(op, *tuple(SymbolicTensor.like(a) for a in args), **params) 1445 # with Timing(f"RUN {op}"):ret = jit( 1446 ret = jit( 1447 fn, 1448 static_argnames=("params",), 1449 name=jit.get_jit_name(args, params, op.name), 1450 )(*args, **params) 1451 1452 return ret 1453 1454 @staticmethod 1455 @lru_cache_verbose() 1456 def get_fn(op, *symval_args, **params): 1457 def fn(*args, **params): 1458 return [op(*args, **params)] 1459 1460 return fn
1463class SymbolicRunTrace(Trace): 1464 # pure = lambda self, x: x 1465 def pure(self, val: Any) -> SymbolicTensor: 1466 return val.symval 1467 1468 def run_op(self, op, tracers, params): 1469 symvals_in = tree_map(lambda x: x.symval, tracers) 1470 symvals_out = op.typecheck(*symvals_in, **params) 1471 return symvals_out
1474class TraceTensor(Tensor): 1475 PYTHON_TYPES = { 1476 bool, 1477 int, 1478 float, 1479 } 1480 _trace: "Trace" 1481 1482 def __init__(self, *args, **kwargs): 1483 raise NotImplementedError 1484 1485 symval = property(lambda self: get_symval(self.val)) 1486 dtype = property(lambda self: self.symval.dtype) 1487 shape = property(lambda self: self.symval.shape) 1488 device = property(lambda self: self.symval.device) 1489 1490 @property 1491 def val(self): 1492 raise NotImplementedError 1493 1494 def __str__(self): 1495 return repr(self) 1496 1497 def full_lower(self): 1498 return self 1499 1500 @property 1501 def ndim(self): 1502 return len(self.shape) 1503 1504 def __repr__(self): 1505 return f"{self.__class__.__name__}({repr(self.symval)})"
1508class VMapTraceTensor(TraceTensor): 1509 def __init__(self, trace, val, vmap_dim): 1510 self._trace = trace 1511 self._val = val 1512 self.vmap_dim = vmap_dim 1513 1514 @property 1515 def val(self): 1516 return self._val 1517 1518 @property 1519 def symval(self): 1520 symval = get_symval(self.val) 1521 if self.vmap_dim is None: 1522 return symval 1523 else: 1524 shape = list(symval.shape) 1525 del shape[self.vmap_dim] 1526 return symval.like(shape=tuple(shape)) 1527 1528 def full_lower(self): 1529 if self.vmap_dim is None: 1530 return full_lower(self.val) 1531 else: 1532 return self
1535class VMapTrace(Trace): 1536 pure = lambda self, val: VMapTraceTensor(self, val, None) 1537 1538 @property 1539 def dim_size(self): 1540 return self.main.global_data 1541 1542 def run_op(self, op, tracers, params): 1543 vals_in, bdims_in = unzip2((t.val, t.vmap_dim) for t in tracers) 1544 val_outs, bdim_outs = op.vmap(self.dim_size, vals_in, bdims_in, **params) 1545 return [VMapTraceTensor(self, x, bd) for x, bd in list_zip(val_outs, bdim_outs)] 1546 1547 @staticmethod 1548 def move_vmap_dim(x, dim_size, src: int, dst: int): 1549 if src is None: # unsqueeze and expand 1550 target_shape = list(x.shape) 1551 target_shape.insert(dst, dim_size) 1552 unsqueeze_shape = [1 if d == dst else target_shape[d] for d in range(len(target_shape))] 1553 x = x.reshape(tuple(unsqueeze_shape)) 1554 x = x.expand(tuple(target_shape)) 1555 return x 1556 elif src == dst: 1557 return x 1558 else: 1559 perm = [i for i in range(len(x.shape)) if i != src] 1560 perm.insert(dst, src) 1561 return x.permute(tuple(perm))
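move_vmap_dim handles three cases: materialize a missing batch dimension, leave it where it is, or move it to the target position. The same logic in plain numpy, for intuition only (not how slope executes it):

import numpy as np

def move_dim_np(x, dim_size, src, dst):
    if src is None:  # no batch dim yet: insert one and broadcast it to dim_size
        return np.broadcast_to(np.expand_dims(x, dst), x.shape[:dst] + (dim_size,) + x.shape[dst:])
    if src == dst:
        return x
    return np.moveaxis(x, src, dst)  # otherwise just relocate the axis

assert move_dim_np(np.zeros((3, 4)), 5, None, 0).shape == (5, 3, 4)
assert move_dim_np(np.zeros((3, 5, 4)), 5, 1, 0).shape == (5, 3, 4)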
1564class JVPTraceTensor(TraceTensor): 1565 def __init__(self, trace, primal, tangent): 1566 self._trace = trace 1567 self.primal = primal 1568 self.tangent = tangent 1569 1570 @property 1571 def symval(self): 1572 return get_symval(self.primal) 1573 1574 @property 1575 def val(self): 1576 return self.primal 1577 1578 @property 1579 def dtype(self): 1580 return self.primal.dtype
1583class JVPTrace(Trace): 1584 def pure(self, val): 1585 if isinstance(val, PartialRunTrace): 1586 val = val.pval.const 1587 return JVPTraceTensor(self, val, backend.zeros_like(val)) 1588 1589 def run_op(self, op, tracers, params): 1590 primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) 1591 primals_out, tangents_out = op.jvp(primals_in, tangents_in, **params) 1592 return [JVPTraceTensor(self, x, t) for x, t in list_zip(primals_out, tangents_out)]
1595class ProgramTraceTensor(TraceTensor): 1596 __slots__ = ["symval"] 1597 symval: SymbolicTensor 1598 1599 def __init__(self, trace, symval): 1600 self._trace = trace 1601 self.symval = symval
1604class ProgramTrace(Trace): 1605 @property 1606 def builder(self): 1607 return self.main.global_data 1608 1609 def new_arg(self, symval) -> ProgramTraceTensor: 1610 symval = SymbolicTensor.like(symval) 1611 tracer = self.builder.new_tracer(self, symval) 1612 self.builder.tracer_to_var[id(tracer)] = Var(symval) 1613 1614 return tracer 1615 1616 def pure(self, val: Any) -> ProgramTraceTensor: 1617 # get_or_make_const_tracer 1618 tracer = self.builder.const_tracers.get(id(val)) 1619 if tracer is None: 1620 tracer = self.builder.new_tracer(self, get_symval(val)) 1621 self.builder.add_const(tracer, val) 1622 # print(self.builder.const_tracers) 1623 return tracer 1624 1625 def run_op(self, op, tracers, params): 1626 symvals_in = tree_map(lambda x: x.symval, tracers) 1627 symvals_out = op.typecheck(*symvals_in, **params) 1628 1629 out_tracers = [self.builder.new_tracer(self, a) for a in symvals_out] 1630 inputs = [self.builder.getvar(t) for t in tracers] 1631 outvars = [self.builder.add_var(t) for t in out_tracers] 1632 1633 self.builder.add_instruction(Instruction(op, inputs, params, outvars)) 1634 return out_tracers
1637class ProgramBuilder: 1638 instructions: List[Instruction] 1639 tracer_to_var: Dict[int, Var] 1640 const_tracers: Dict[int, TraceTensor] 1641 constvals: Dict[Var, Any] 1642 tracers: List[ProgramTraceTensor] 1643 1644 def __init__(self): 1645 self.instructions = [] 1646 self.tracer_to_var = {} 1647 self.const_tracers = {} 1648 self.constvals = {} 1649 self.tracers = [] 1650 1651 def new_tracer(self, trace: ProgramTrace, symval: SymbolicTensor) -> ProgramTraceTensor: 1652 tracer = ProgramTraceTensor(trace, symval) 1653 self.tracers += [tracer] 1654 return tracer 1655 1656 def add_instruction(self, instruction: Instruction) -> None: 1657 self.instructions += [instruction] 1658 1659 def add_var(self, tracer: ProgramTraceTensor) -> Var: 1660 assert id(tracer) not in self.tracer_to_var 1661 var = self.tracer_to_var[id(tracer)] = Var(tracer.symval) 1662 return var 1663 1664 def getvar(self, tracer: ProgramTraceTensor) -> Var: 1665 var = self.tracer_to_var.get(id(tracer)) 1666 assert var is not None 1667 return var 1668 1669 def add_const(self, tracer: ProgramTraceTensor, val: Any) -> Var: 1670 var = self.add_var(tracer) 1671 self.const_tracers[id(val)] = tracer 1672 self.constvals[var] = val 1673 return var 1674 1675 def build(self, in_tracers: Any, out_tracers: Any, static_args, name) -> Tuple[Program, List[Any]]: 1676 constvars, constvals = unzip2(self.constvals.items()) 1677 t2v = lambda t: self.tracer_to_var[id(t)] 1678 in_binders = constvars + [t2v(t) for t in in_tracers] 1679 out_vars = [t2v(t) for t in out_tracers] 1680 program = Program( 1681 in_binders, 1682 self.instructions, 1683 out_vars, 1684 len(constvals), 1685 static_args, 1686 name, 1687 ) 1688 typecheck_program(program) 1689 program, constvals = self._inline_literals(program, constvals) 1690 typecheck_program(program) 1691 # dblog(program, enable=backend.LOG_PROGRAM) 1692 return program, constvals 1693 1694 def _inline_literals(self, program: Program, consts: List[Any]) -> Tuple[Program, List[Any]]: 1695 const_binders, other_binders = split_list(program.in_binders, len(consts)) 1696 scalars = [type(x) in TraceTensor.PYTHON_TYPES and not get_symval(x).shape for x in consts] 1697 new_const_binders, lit_binders = partition_list(scalars, const_binders) 1698 new_consts, lit_vals = partition_list(scalars, consts) 1699 literals = dict(list_zip(lit_binders, list_map(Lit, lit_vals))) 1700 new_outs = [literals.get(x, x) for x in program.outs] 1701 new_instructions = [ 1702 Instruction( 1703 instruction.op, 1704 [literals.get(x, x) for x in instruction.inputs], 1705 instruction.params, 1706 instruction.out_binders, 1707 ) 1708 for instruction in program.instructions 1709 ] 1710 new_program = Program( 1711 new_const_binders + other_binders, 1712 new_instructions, 1713 new_outs, 1714 len(new_consts), 1715 program.static_args, 1716 program.name, 1717 ) 1718 return new_program, tuple(new_consts) 1719 1720 def get_current_scope_info(self): 1721 current_frame = inspect.currentframe() 1722 current_function_name = current_frame.f_code.co_name 1723 current_module_name = inspect.getmodulename(current_frame.f_code.co_filename) 1724 current_class_name = None 1725 for frame_info in inspect.getouterframes(current_frame): 1726 print(frame_info) 1727 frame_locals = frame_info.frame.f_locals 1728 print(frame_locals) 1729 if "self" in frame_locals: 1730 current_class_name = frame_locals["self"].__class__.__name__ 1731 break 1732 return { 1733 "Function": current_function_name, 1734 "Module": current_module_name, 1735 "Class": current_class_name, 1736 }
1739class UndefinedPrimal(NamedTuple): 1740 symval: SymbolicTensor 1741 1742 @property 1743 def shape(self): 1744 return self.symval.shape 1745 1746 @property 1747 def dtype(self): 1748 return self.symval.dtype 1749 1750 @property 1751 def device(self): 1752 return self.symval.device 1753 1754 @property 1755 def ndim(self): 1756 return self.symval.ndim 1757 1758 def __repr__(self): 1759 return f"<UndefinedPrimal: symval={self.symval}>" 1760 1761 str_short = __repr__
1764class PartialValue(NamedTuple): 1765 symval: SymbolicTensor 1766 const: Optional[Any] 1767 1768 is_known = property(lambda self: self.const is not None) 1769 is_unknown = property(lambda self: self.const is None)
LambdaBindingDraft()
ConstDraft(val)
1780class InstructionDraft(NamedTuple): 1781 prim: Operator 1782 tracers_in: List["PartialRunTraceTensor"] 1783 params: Dict[str, Any] 1784 symvals_out: List[SymbolicTensor] 1785 tracer_refs_out: List[weakref.ReferenceType["PartialRunTraceTensor"]]
1791class PartialRunTraceTensor(TraceTensor): 1792 def __init__(self, trace, pval, draft): 1793 self._trace = trace 1794 self.pval = pval 1795 self.draft = draft 1796 1797 symval = property(lambda self: self.pval.symval) 1798 val = property(lambda self: self.pval.const) 1799 1800 def full_lower(self): 1801 if self.pval.is_known: 1802 return full_lower(self.pval.const) 1803 return self
1806class PartialRunTrace(Trace): 1807 def new_arg(self, pval: PartialValue) -> Any: 1808 return PartialRunTraceTensor(self, pval, LambdaBindingDraft()) 1809 1810 def pure(self, val: Any) -> PartialRunTraceTensor: 1811 return PartialRunTraceTensor(self, make_known_pval(val), None) 1812 1813 def instantiate_const(self, tracer: PartialRunTraceTensor) -> PartialRunTraceTensor: 1814 if tracer.pval.is_unknown: 1815 return tracer 1816 else: 1817 pval = make_unknown_pval(SymbolicTensor.like(tracer.symval)) 1818 return PartialRunTraceTensor(self, pval, ConstDraft(tracer.pval.const)) 1819 1820 def run_op(self, op, tracers, params): 1821 is_knowns = tuple(t.pval.is_known for t in tracers) 1822 1823 if all(is_knowns): 1824 return bind(op, *list_map(full_lower, tracers), **params) 1825 return op.partial_run(self, tracers, **params)
1841def set_backend(name, where="slope.backends"): 1842 global backend 1843 backend = importlib.import_module(f"{where}.{name}").backend 1844 import slope.nn as nn 1845 1846 # backend.register_node(nn.Module, nn.Module.flatten, nn.Module.unflatten, "Module") 1847 1848 dblog(f"slope backend is {backend}", enable=backend.LOG_INIT)
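A hedged usage sketch; the module name below is hypothetical and must match a module under slope.backends that exposes a backend object:

set_backend("my_backend")  # imports slope.backends.my_backend and binds its `backend` global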
1879def tree_flatten(x: Any) -> Any: 1880 def _tree_flatten(x_: Any) -> Tuple[Iterable, Union[TreeDef, Leaf]]: 1881 node_type = None 1882 for k in backend.node_types.keys(): 1883 if isinstance(x_, k): 1884 node_type = backend.node_types[k] 1885 1886 if node_type is not None: 1887 node_metadata, children = node_type.flatten(x_) 1888 children_flat, child_trees = unzip2(list_map(_tree_flatten, children)) 1889 children_iter = itertools.chain.from_iterable(children_flat) 1890 treedef = TreeDef(node_type, node_metadata, tuple(child_trees)) 1891 return children_iter, treedef 1892 else: 1893 return (x_,), Leaf(x_) 1894 1895 children_iter, treedef = _tree_flatten(x) 1896 return tuple(children_iter), treedef
1899def tree_unflatten(treedef: TreeDef, xs: Tuple[Any]) -> Any: 1900 def _tree_unflatten(treedef_: TreeDef, xs_: Iterator) -> Any: 1901 if isinstance(treedef_, Leaf): 1902 dblog(f" tree leaf found: {xs_}\n", enable=backend.LOG_TREE) 1903 return next(xs_) 1904 else: 1905 dblog(f" now\n {treedef_}", enable=backend.LOG_TREE) 1906 children = (_tree_unflatten(t, xs_) for t in treedef_.child_treedefs) 1907 dblog(f"{children=}\n", enable=backend.LOG_TREE) 1908 return treedef_.node_type.unflatten(treedef_.node_metadata, children) 1909 1910 dblog(f"unflattening {treedef}", enable=backend.LOG_TREE) 1911 return _tree_unflatten(treedef, iter(xs)) 1912 # with Timing(f"\nTREE:\n{treedef}"): 1913 # ret = _tree_unflatten(treedef, iter(xs)) 1914 # return ret
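A hedged round-trip sketch (it assumes a backend has been set, since tree_flatten dispatches on backend.node_types, which registers tuple, list, and dict):

tree = {"w": [1.0, 2.0], "b": (3.0,)}
leaves, treedef = tree_flatten(tree)
# leaves is a flat tuple of the scalar leaves (dict keys are visited in sorted order),
# and treedef records the dict/list/tuple structure needed to rebuild it.
rebuilt = tree_unflatten(treedef, leaves)
assert rebuilt == tree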
1917def tree_transpose( 1918 outer_treedef: TreeDef, 1919 inner_treedef: TreeDef, 1920 tree_to_transpose: Any, 1921) -> Any: 1922 flat, treedef = tree_flatten(tree_to_transpose) 1923 inner_size = inner_treedef.num_leaves 1924 outer_size = outer_treedef.num_leaves 1925 if treedef.num_leaves != (inner_size * outer_size): 1926 raise TypeError 1927 iter_flat = iter(flat) 1928 lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)] 1929 permuted_lol = zip(*lol) 1930 subtrees = map(partial(tree_unflatten, outer_treedef), permuted_lol) 1931 return tree_unflatten(inner_treedef, subtrees)
1934def flatten_fn(f, in_tree, *, has_aux=False): 1935 store = Store() 1936 1937 def flat_fn(*args_flat, **params): 1938 tree_args = tree_unflatten(in_tree, args_flat) 1939 out = f(*tree_args, **params) 1940 if has_aux: 1941 out, aux = out 1942 out_flat, out_tree = tree_flatten(out) 1943 store.set_value(out_tree) 1944 return (out_flat, aux) if has_aux else out_flat 1945 1946 return flat_fn, store
1949def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any: 1950 leaves, treedef = tree_flatten(tree) 1951 if len(rest) == 0: 1952 out_tree_flat = tuple(f(leaf) for leaf in leaves) 1953 out_tree = tree_unflatten(treedef, out_tree_flat) 1954 else: 1955 all_leaves = [leaves] 1956 for t in rest: 1957 t_leaves, t_treedef = tree_flatten(t) 1958 assert t_treedef == treedef 1959 all_leaves += [t_leaves] 1960 1961 out_tree_flat = tuple(f(*xs) for xs in zip(*all_leaves)) 1962 out_tree = tree_unflatten(treedef, out_tree_flat) 1963 ret = out_tree 1964 if out_leaf: 1965 ret = (ret, tree_flatten(out_tree_flat[0])) 1966 return ret
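tree_map applies f to every leaf and rebuilds the same structure; when extra trees are given their structures must match and leaves are zipped positionally. A hedged sketch with plain Python leaves (again assuming a backend is set):

params = {"w": [1.0, 2.0], "b": 3.0}
grads = {"w": [0.1, 0.2], "b": 0.3}
updated = tree_map(lambda p, g: p - 0.5 * g, params, grads)
# updated == {"w": [0.95, 1.9], "b": 2.85}, with the same nesting as params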
1990def find_top_trace(xs) -> Trace: 1991 arrs = [] 1992 1993 def get_arr_from_seq(seq): 1994 nonlocal arrs 1995 for x in seq: 1996 if type(x) in (tuple, list): 1997 get_arr_from_seq(x) 1998 elif isinstance(x, TraceTensor): 1999 arrs += [x] 2000 2001 get_arr_from_seq(xs) 2002 arrs = tuple(arrs) 2003 top_main = max( 2004 (x._trace.main for x in arrs), 2005 default=trace_stack[0], 2006 key=operator_py.attrgetter("level"), 2007 ) 2008 if stashed_trace and stashed_trace.level > top_main.level: 2009 top_main = stashed_trace 2010 return top_main.trace_type(top_main)
2013def full_raise(trace: Trace, val: Any) -> TraceTensor: 2014 if not isinstance(val, TraceTensor): 2015 return trace.pure(val) 2016 level = trace.main.level 2017 if val._trace.main is trace.main: 2018 return val 2019 elif val._trace.main.level < level: 2020 return trace.pure(val) 2021 elif val._trace.main.level > level: 2022 raise Exception(f"Can't lift level {val._trace.main.level} to {level}.") 2023 else: 2024 raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
2036def typecheck_program(program: Program) -> ProgramType: 2037 env: Set[Var] = set() 2038 2039 for v in program.in_binders: 2040 if v in env: 2041 raise TypeError 2042 env.add(v) 2043 2044 for instruction in program.instructions: 2045 in_types = [typecheck_atom(env, x) for x in instruction.inputs] 2046 out_types = instruction.op.typecheck(*in_types, **instruction.params) 2047 for out_binder, out_type in list_zip(instruction.out_binders, out_types): 2048 if not out_type == out_binder.symval: 2049 raise TypeError 2050 for out_binder in instruction.out_binders: 2051 if out_binder in env: 2052 raise TypeError 2053 env.add(out_binder) 2054 2055 in_types = [v.symval for v in program.in_binders] 2056 out_types = [typecheck_atom(env, x) for x in program.outs] 2057 return ProgramType(tuple(in_types), tuple(out_types))
2071def run_program(program: Program, args: List[Any]) -> List[Any]: 2072 env: Dict[Var, Any] = {} 2073 2074 def read(x: Atom) -> Any: 2075 return env[x] if type(x) is Var else x.val 2076 2077 def write(v: Var, val: Any) -> None: 2078 assert v not in env # single-assignment 2079 env[v] = val 2080 2081 list_map(write, program.in_binders, args) 2082 for instruction in program.instructions: 2083 in_vals = list_map(read, instruction.inputs) 2084 outs = bind(instruction.op, *in_vals, **instruction.params) 2085 list_map(write, instruction.out_binders, outs) 2086 return list_map(read, program.outs)
2093def vmap_flat(f, in_dim, out_dim, dim_size, *args): 2094 if dim_size is None: 2095 dims = set([x.shape[d] for x, d in list_zip(args, in_dim) if d is not None]) 2096 assert len(dims) == 1 2097 (dim_size,) = dims 2098 with new_main_trace(VMapTrace, dim_size) as main: 2099 trace = VMapTrace(main) 2100 tracers_in = [VMapTraceTensor(trace, x, dim) if dim is not None else x for x, dim in list_zip(args, in_dim)] 2101 outs = f(*tracers_in) 2102 tracers_out = [full_raise(trace, out) for out in outs] 2103 vals_out, y_vmap_dims = unzip2((t.val, t.vmap_dim) for t in tracers_out) 2104 ret = [VMapTrace.move_vmap_dim(val_out, dim_size, bdim, out_dim) for val_out, bdim, out_dim in zip(vals_out, y_vmap_dims, out_dim)] 2105 return ret
2108def vmap(f, in_dim=0, out_dim=0, dim_size=None): 2109 def batched_f(*args): 2110 nonlocal in_dim, out_dim, dim_size 2111 args_flat, in_tree = tree_flatten(args) 2112 in_dim = (in_dim,) * len(args) if isinstance(in_dim, int) else in_dim 2113 out_dim = (out_dim,) * len(args) if isinstance(out_dim, int) else out_dim 2114 in_dim_flat, in_dim_tree = tree_flatten(in_dim) 2115 out_dim_flat, out_dim_tree = tree_flatten(out_dim) 2116 if not (in_tree == in_dim_tree == out_dim_tree): 2117 raise TypeError(f"\n{in_tree}\n!=\n{in_dim_tree}!=\n{out_dim_tree}") 2118 f_flat, out_tree_store = flatten_fn(f, in_tree) 2119 # if len(args_flat) > len(in_dim_flat): 2120 # in_dim_flat = (in_dim[0],) * len(args_flat) 2121 outs_flat = vmap_flat(f_flat, in_dim_flat, out_dim_flat, dim_size, *args_flat) 2122 return tree_unflatten(out_tree_store(), outs_flat) 2123 2124 return batched_f
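# --- Illustrative usage sketch (not part of the library, never called) ---
# vmap maps a per-example function over a batch dimension of its inputs.
# ASSUMPTION (not shown in this section): Tensor supports elementwise
# multiplication; substitute any per-example function available in the library.
def _example_vmap():
    xs = backend.ones((8, 3))                      # a batch of 8 examples
    square = lambda x: x * x                       # assumed elementwise op
    return vmap(square, in_dim=0, out_dim=0)(xs)   # batched result, one row per example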
2127def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args): 2128 with new_main_trace(JVPTrace, global_data) as main: 2129 trace = JVPTrace(main) 2130 tracers_in = [JVPTraceTensor(trace, x, t) for x, t in list_zip(primals, tangents)] 2131 jvp_flat_ret = f(*tracers_in, **static_args) 2132 if has_aux: 2133 (outs, aux) = jvp_flat_ret 2134 # aux_ = aux 2135 aux = tree_map(lambda x: x.primal, aux) 2136 # aux = tree_map(lambda x: x.full_lower(), aux) 2137 # 2138 else: 2139 outs = jvp_flat_ret 2140 tracers_out = [full_raise(trace, out) for out in outs] 2141 primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out) 2142 return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
2145def jvp(f, primals, tangents, *, has_aux=False, global_data=None, **static_args): 2146 primals_flat, in_tree = tree_flatten(primals) 2147 tangents_flat, in_tree2 = tree_flatten(tangents) 2148 for p, t in zip(primals_flat, tangents_flat): 2149 assert p.shape == t.shape, f"{p.shape=} != {t.shape=}" 2150 assert p.dtype == t.dtype, f"{p.dtype=} != {t.dtype=}" 2151 assert p.device == t.device, f"{p.device=} != {t.device=}" 2152 if in_tree != in_tree2: 2153 raise TypeError 2154 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2155 jvp_ret = jvp_flat( 2156 f, 2157 primals_flat, 2158 tangents_flat, 2159 has_aux=has_aux, 2160 global_data=global_data, 2161 **static_args, 2162 ) 2163 if has_aux: 2164 (primals_out_flat, tangents_out_flat), aux = jvp_ret 2165 else: 2166 (primals_out_flat, tangents_out_flat) = jvp_ret 2167 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2168 tangents_out = tree_unflatten(out_tree_store(), tangents_out_flat) 2169 return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
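# --- Illustrative usage sketch (not part of the library, never called) ---
# jvp evaluates f at `primals` together with the Jacobian-vector product along
# `tangents`. ASSUMPTION: Tensor supports elementwise multiplication.
def _example_jvp():
    x = backend.ones((3,))
    v = backend.ones((3,))
    f = lambda t: t * t                        # assumed elementwise op
    y, y_dot = jvp(f, (x,), (v,))              # y = x*x, y_dot is the derivative along v
    return y, y_dot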
2201@contextmanager 2202def symbolic_run(): 2203 global trace_stack 2204 level = len(trace_stack) 2205 main = MainTrace(level, SymbolicRunTrace, global_data=None) 2206 trace_stack += [main] 2207 global stashed_trace 2208 prev_stashed_trace, stashed_trace = stashed_trace, main 2209 try: 2210 yield 2211 finally: 2212 stashed_trace = prev_stashed_trace 2213 trace_stack.pop()
2271def partial_run_flat( 2272 f: Callable, pvals_in: List["PartialValue"], has_aux, global_data=None 2273) -> Tuple[Program, List["PartialValue"], List[Any]]: 2274 with new_main_trace(PartialRunTrace, global_data) as main: 2275 trace = PartialRunTrace(main) 2276 tracers_in = [trace.new_arg(pval) for pval in pvals_in] 2277 outs = f(*tracers_in) 2278 if has_aux: 2279 outs, aux = outs 2280 tracers_out = [full_raise(trace, out) for out in outs] 2281 pvals_out = [t.pval for t in tracers_out] 2282 unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown] 2283 unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown] 2284 program, consts = tracers_to_program(unk_tracers_in, unk_tracers_out) 2285 2286 return (program, pvals_out, consts, aux) if has_aux else (program, pvals_out, consts)
2289def partial_run_program( 2290 program: Program, 2291 in_unknowns: List[bool], 2292 instantiate: Optional[List[bool]] = None, 2293) -> Tuple[Program, Program, List[bool], int]: 2294 env: Dict[Var, bool] = {} 2295 residuals: Set[Var] = set() 2296 2297 def read(x: Atom) -> bool: 2298 return type(x) is Var and env[x] 2299 2300 def write(unk: bool, v: Var) -> None: 2301 env[v] = unk 2302 2303 instructions1, instructions2 = [], [] 2304 list_map(write, in_unknowns, program.in_binders) 2305 2306 for instruction in program.instructions: 2307 unks_in = list_map(read, instruction.inputs) 2308 ( 2309 instruction1, 2310 instruction2, 2311 unks_out, 2312 res, 2313 ) = instruction.op.partial_run_instruction(unks_in, instruction) 2314 if instruction1 is not None: 2315 instructions1 += [instruction1] 2316 if instruction2 is not None: 2317 instructions2 += [instruction2] 2318 if res is not None: 2319 residuals.update(res) 2320 list_map(write, unks_out, instruction.out_binders) 2321 2322 out_unknowns = list_map(read, program.outs) 2323 if instantiate is not None: 2324 for v, uk, inst in zip(program.outs, out_unknowns, instantiate): 2325 if inst and not uk: 2326 if type(v) is Var: 2327 residuals.add(v) 2328 out_unknowns = list_map(operator_py.or_, out_unknowns, instantiate) 2329 2330 residuals, num_res = list(residuals), len(residuals) 2331 assert all(type(v) is Var for v in residuals), residuals 2332 2333 ins1, ins2 = partition_list(in_unknowns, program.in_binders) 2334 outs1, outs2 = partition_list(out_unknowns, program.outs) 2335 2336 program1 = Program( 2337 ins1, 2338 instructions1, 2339 outs1 + residuals, 2340 0, 2341 program.static_args, 2342 f"{program.name}_partial1", 2343 ) 2344 program2 = Program( 2345 residuals + ins2, 2346 instructions2, 2347 outs2, 2348 0, 2349 program.static_args, 2350 f"{program.name}_partial2", 2351 ) 2352 typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2) 2353 2354 return program1, program2, out_unknowns, num_res
2357def typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2): 2358 programty = typecheck_program(program) # (a1, a2) -> (b1, b2 ) 2359 program1ty = typecheck_program(program1) # a1 -> (b1, res) 2360 program2ty = typecheck_program(program2) # (res, a2) -> b2 2361 2362 a1, a2 = partition_list(in_unknowns, programty.in_types) 2363 b1, b2 = partition_list(out_unknowns, programty.out_types) 2364 b1_, res = split_list(program1ty.out_types, len(b1)) 2365 res_, a2_ = split_list(program2ty.in_types, len(res)) 2366 b2_ = program2ty.out_types 2367 2368 a1 = tuple(a1) 2369 a2, a2_ = tuple(a2), tuple(a2_) 2370 b1, b1_ = tuple(b1), tuple(b1_) 2371 b2, b2_ = tuple(b2), tuple(b2_) 2372 res, res_ = tuple(res), tuple(res_) 2373 2374 if program1ty.in_types != a1: 2375 raise TypeError 2376 if program2ty.out_types != b2: 2377 raise TypeError 2378 if b1 != b1_: 2379 raise TypeError 2380 if res != res_: 2381 raise TypeError 2382 if a2 != a2_: 2383 raise TypeError 2384 if b2 != b2_: 2385 raise TypeError
2388def linearize_flat(f, *primals_in, has_aux): 2389 pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in] 2390 2391 def f_jvp(*primals_tangents_in): 2392 jvp_ret = jvp(f, *split_half(primals_tangents_in), has_aux=has_aux) 2393 if has_aux: 2394 (primals_out, tangents_out), aux = jvp_ret 2395 return ((*primals_out, *tangents_out), aux) 2396 else: 2397 primals_out, tangents_out = jvp_ret 2398 return (*primals_out, *tangents_out) 2399 2400 partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux) 2401 if has_aux: 2402 program, pvals_out, consts, aux = partial_run_flat_ret 2403 else: 2404 program, pvals_out, consts = partial_run_flat_ret 2405 primal_pvals, _ = split_half(pvals_out) 2406 assert all(pval.is_known for pval in primal_pvals) 2407 primals_out = [pval.const for pval in primal_pvals] 2408 f_lin = lambda *tangents: run_program(program, [*consts, *tangents]) 2409 return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
2412def linearize(f, *primals_in, has_aux=False): 2413 primals_in_flat, in_tree = tree_flatten(primals_in) 2414 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2415 linearize_flat_ret = linearize_flat(f, *primals_in_flat, has_aux=has_aux) 2416 if has_aux: 2417 primals_out_flat, f_lin_flat, aux = linearize_flat_ret 2418 else: 2419 primals_out_flat, f_lin_flat = linearize_flat_ret 2420 2421 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2422 2423 def f_lin(*tangents_in): 2424 tangents_in_flat, in_tree2 = tree_flatten(tangents_in) 2425 if in_tree != in_tree2: 2426 raise TypeError 2427 tangents_out_flat = f_lin_flat(*tangents_in_flat) 2428 return tree_unflatten(out_tree_store(), tangents_out_flat) 2429 2430 return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
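# --- Illustrative usage sketch (not part of the library, never called) ---
# linearize returns the primal output plus a function that replays only the
# tangent half of the jvp program, so repeated JVPs at the same point avoid
# re-tracing f. ASSUMPTION: Tensor supports elementwise multiplication.
def _example_linearize():
    x = backend.ones((3,))
    f = lambda t: t * t                        # assumed elementwise op
    y, f_lin = linearize(f, x)
    return y, f_lin(backend.ones((3,)))        # directional derivative at x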
2433def tracers_to_program(
2434    tracers_in: List["PartialRunTraceTensor"],
2435    tracers_out: List["PartialRunTraceTensor"],
2436):
2437    def tracer_parents(t: PartialRunTraceTensor) -> List[PartialRunTraceTensor]:
2438        return t.draft.tracers_in if isinstance(t.draft, InstructionDraft) else []
2439
2440    def draft_to_instruction(tracer_to_var: Dict[int, Var], draft: InstructionDraft) -> Instruction:
2441        inputs = [tracer_to_var[id(t)] for t in draft.tracers_in]
2442        out_binders = [Var(symval) for symval in draft.symvals_out]
2443        for t_ref, var in list_zip(draft.tracer_refs_out, out_binders):
2444            if t_ref() is not None:
2445                tracer_to_var[id(t_ref())] = var
2446        return Instruction(draft.prim, inputs, draft.params, out_binders)
2447
2448    tracer_to_var: Dict[int, Var] = {id(t): Var(SymbolicTensor.like(t.symval)) for t in tracers_in}
2449    constvar_to_val: Dict[Var, Any] = {}
2450    constid_to_var: Dict[int, Var] = {}
2451    processed_instructions: Set[int] = set()
2452    instructions: List[Instruction] = []
2453    for t in toposort(tracers_out, tracer_parents):
2454        if isinstance(t.draft, LambdaBindingDraft):
2455            assert id(t) in set(list_map(id, tracers_in))
2456        elif isinstance(t.draft, ConstDraft):
2457            val = t.draft.val
2458            var = constid_to_var.get(id(val))
2459            if var is None:
2460                symval = SymbolicTensor.like(get_symval(val))
2461                var = constid_to_var[id(val)] = Var(symval)
2462                constvar_to_val[var] = val
2463            tracer_to_var[id(t)] = var
2464        elif isinstance(t.draft, InstructionDraft):
2465            if id(t.draft) not in processed_instructions:
2466                instructions += [draft_to_instruction(tracer_to_var, t.draft)]
2467                processed_instructions.add(id(t.draft))
2468        else:
2469            raise TypeError(t.draft)
2470
2471    constvars, constvals = unzip2(constvar_to_val.items())
2472    in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
2473    # out_vars = [tracer_to_var[id(t)] for t in tracers_out if id(t) in tracer_to_var]
2474    out_vars = [tracer_to_var[id(t)] for t in tracers_out]
2475    program = Program(tuple(in_binders), tuple(instructions), tuple(out_vars))
2476    typecheck_program(program)
2477    return program, constvals
2480def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]): 2481 def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]): 2482 seen = set() 2483 for node in nodes: 2484 assert all(id(parent) in seen for parent in parents(node)) 2485 seen.add(id(node)) 2486 2487 def remove_duplicates(lst): 2488 seen = set() 2489 return [x for x in lst if id(x) not in seen and not seen.add(id(x))] 2490 2491 if not out_nodes: 2492 return [] 2493 out_nodes = remove_duplicates(out_nodes) 2494 2495 child_counts = {} 2496 stack = list(out_nodes) 2497 while stack: 2498 node = stack.pop() 2499 if id(node) in child_counts: 2500 child_counts[id(node)] += 1 2501 else: 2502 child_counts[id(node)] = 1 2503 stack.extend(parents(node)) 2504 for node in out_nodes: 2505 child_counts[id(node)] -= 1 2506 2507 sorted_nodes = [] 2508 childless_nodes = [node for node in out_nodes if not child_counts[id(node)]] 2509 while childless_nodes: 2510 node = childless_nodes.pop() 2511 sorted_nodes += [node] 2512 for parent in parents(node): 2513 if child_counts[id(parent)] == 1: 2514 childless_nodes += [parent] 2515 else: 2516 child_counts[id(parent)] -= 1 2517 2518 sorted_nodes = sorted_nodes[::-1] 2519 check_toposort(sorted_nodes, parents) 2520 return sorted_nodes
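# --- Illustrative usage sketch (not part of the library, never called) ---
# toposort orders nodes so every node appears after all of its parents; it works
# on any object graph given a `parents` callable. A tiny ad-hoc node class stands
# in here for the tracer graph used by tracers_to_program.
def _example_toposort():
    class _Node:
        def __init__(self, name, parents=()):
            self.name, self.parents = name, list(parents)

    a = _Node("a")
    b = _Node("b", [a])
    c = _Node("c", [a, b])
    order = toposort([c], lambda n: n.parents)
    return [n.name for n in order]             # expected: ["a", "b", "c"]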
2523def vjp_flat(f, *primals_in, has_aux=False, **static_args): 2524 pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in] 2525 primal_pvals_in, tangent_pvals_in = split_half(pvals_in) 2526 del primal_pvals_in 2527 2528 def f_jvp(*primals_tangents_in): 2529 jvp_ret = jvp( 2530 f, 2531 *split_half(primals_tangents_in), 2532 has_aux=has_aux, 2533 global_data="vjp", 2534 **static_args, 2535 ) 2536 if has_aux: 2537 ((primals_out, tangents_out), aux) = jvp_ret 2538 else: 2539 (primals_out, tangents_out) = jvp_ret 2540 return ([*primals_out, *tangents_out], aux) if has_aux else [*primals_out, *tangents_out] 2541 2542 partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux, "vjp") 2543 if has_aux: 2544 program, pvals_out, consts, aux = partial_run_flat_ret 2545 else: 2546 program, pvals_out, consts = partial_run_flat_ret 2547 2548 primal_pvals, tangent_pvals = split_half(pvals_out) 2549 del tangent_pvals 2550 assert all(pval.is_known for pval in primal_pvals) 2551 primals_out_flat = [pval.const for pval in primal_pvals] 2552 transpose_inputs = consts + [UndefinedPrimal(t.symval) for t in tangent_pvals_in] 2553 2554 def f_vjp_flat(*cotangents): 2555 # return backward_pass(program, transpose_inputs, cotangents) 2556 undef_primals = tuple(isinstance(x, UndefinedPrimal) for x in transpose_inputs) 2557 transposed_program, new_consts = transpose_program(program, undef_primals) 2558 residuals, _ = partition_list(undef_primals, transpose_inputs) 2559 outs = run_program(transposed_program, (*new_consts, *residuals, *cotangents)) 2560 return outs 2561 2562 return (primals_out_flat, f_vjp_flat, aux) if has_aux else (primals_out_flat, f_vjp_flat)
2565def vjp(f, *primals_in, has_aux=False, **static_args): 2566 primals_in_flat, in_tree = tree_flatten(primals_in) 2567 f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux) 2568 vjp_ret = vjp_flat(f, *primals_in_flat, has_aux=has_aux, **static_args) 2569 if has_aux: 2570 primals_out_flat, f_vjp_flat, aux = vjp_ret 2571 else: 2572 primals_out_flat, f_vjp_flat = vjp_ret 2573 primals_out = tree_unflatten(out_tree_store(), primals_out_flat) 2574 2575 def f_vjp(*cotangents_out): 2576 cotangents_out_flat, _ = tree_flatten(cotangents_out) 2577 cotangents_in_flat = f_vjp_flat(*cotangents_out_flat) 2578 2579 return tree_unflatten(in_tree, cotangents_in_flat) 2580 2581 return (primals_out, f_vjp, aux) if has_aux else (primals_out, f_vjp)
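# --- Illustrative usage sketch (not part of the library, never called) ---
# vjp returns the primal output and a function mapping output cotangents back to
# input cotangents. ASSUMPTION: Tensor supports elementwise multiplication.
def _example_vjp():
    x = backend.ones((3,))
    f = lambda t: t * t                        # assumed elementwise op
    y, f_vjp = vjp(f, x)
    (x_bar,) = f_vjp(backend.ones((3,)))       # cotangent w.r.t. x
    return y, x_bar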
2587def backward_pass(program: Program, args: List[Any], cotangents: List[Any]) -> List[Any]: 2588 primal_env: Dict[Var, Any] = {} 2589 ct_env: Dict[Var, Any] = {} 2590 2591 def read_primal(x: Atom) -> Any: 2592 return primal_env.get(x, UndefinedPrimal(x.symval)) if type(x) is Var else x.val 2593 2594 def write_primal(v: Var, val: Any) -> None: 2595 if type(val) is not UndefinedPrimal: 2596 primal_env[v] = val 2597 2598 def read_cotangent(v: Var) -> Any: 2599 return ct_env.pop(v, backend.zeros(v.symval.shape, v.symval.dtype)) 2600 2601 def write_cotangent(x: Atom, ct: Any): 2602 if type(x) is Var and ct is not NullCotangent: 2603 ct_env[x] = (ct_env[x] + ct) if x in ct_env else ct 2604 2605 list_map(write_primal, program.in_binders, args) 2606 list_map(write_cotangent, program.outs, cotangents) 2607 for instruction in program.instructions[::-1]: 2608 primals_in = list_map(read_primal, instruction.inputs) 2609 cotangents_in = list_map(read_cotangent, instruction.out_binders) 2610 inp, params = primals_in, instruction.params 2611 cotangents_out = instruction.op.T(cotangents_in, *inp, **params) 2612 list_map(write_cotangent, instruction.inputs, cotangents_out) 2613 2614 ret = [read_cotangent(v) for v, x in list_zip(program.in_binders, args) if isinstance(x, UndefinedPrimal)] 2615 return ret
2636def grad(f, argnums=(0,), argnames="", has_aux=False, return_value=False): 2637 f, rejit = (f, False) if not isinstance(f, jit) else (f.f, True) 2638 if isinstance(argnums, int): 2639 argnums = (argnums,) 2640 2641 def gfn(x, *xs, **static_args): 2642 vjp_ret = vjp(f, x, *xs, has_aux=has_aux, **static_args) 2643 if has_aux: 2644 y, f_vjp, aux = vjp_ret 2645 else: 2646 y, f_vjp = vjp_ret 2647 if np.shape(y) != (): 2648 raise TypeError("grad output must be 0-dim scalar with shape ()") 2649 gL_xs = f_vjp(backend.ones(())) 2650 gL_xs = tuple(gL_xs[i] for i in argnums) if len(argnums) > 1 else gL_xs[argnums[0]] 2651 if return_value: 2652 return ((y, aux), gL_xs) if has_aux else (y, gL_xs) 2653 else: 2654 return (gL_xs, aux) if has_aux else gL_xs 2655 2656 return jit(gfn) if rejit else gfn
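# --- Illustrative usage sketch (not part of the library, never called) ---
# grad differentiates a scalar-valued function with respect to the arguments
# selected by argnums. ASSUMPTIONS: Tensor supports elementwise multiplication
# and a sum() reduction yielding the 0-dim scalar output that grad requires.
def _example_grad():
    loss = lambda w: (w * w).sum()             # assumed ops
    return grad(loss)(backend.ones((4,)))      # gradient w.r.t. w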
2669def jit_partial_run(trace, tracers, *, program):
2670    in_unknowns = [not t.pval.is_known for t in tracers]
2671    program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns)
2672    known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
2673    known_vals = [t.pval.const for t in known_tracers]
2674    outs1_res = backend.jit_op(*known_vals, program=program1)  # run the known-input subprogram
2675    outs1, res = split_list(outs1_res, len(program1.outs) - num_res)
2676    res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
2677    outs2 = [PartialRunTraceTensor(trace, PartialValue.unknown(v.symval), None) for v in program2.outs]
2678    draft = InstructionDraft(
2679        backend.jit_op,
2680        res_tracers + unknown_tracers,
2681        dict(program=program2),
2682        [v.symval for v in program2.outs],
2683        map(weakref.ref, outs2),
2684    )
2685    for t in outs2:
2686        t.draft = draft
2687    return merge_lists(out_unknowns, outs1, outs2)
2690class jit: 2691 def __init__(self, f, static_argnames=(), name=None, dynamic_axes=None): 2692 if isinstance(static_argnames, str): 2693 static_argnames = tuple(static_argnames.split(" ")) 2694 assert type(static_argnames) is tuple and all(type(s) is str for s in static_argnames) 2695 self.f = f 2696 self.name = name if name is not None else self.f.__name__ 2697 self.static_argnames = static_argnames 2698 self.dynamic_axes = dynamic_axes 2699 2700 @classmethod 2701 def with_options(cls, **kwargs): 2702 return partial(cls, **kwargs) 2703 2704 @classmethod 2705 def get_jit_name(cls, args, static_args, prefix="jit", short=False): 2706 name = f"{prefix}_" 2707 if short: 2708 static_args_tup = tuple(static_args.items()) 2709 ids = repr(hash((prefix, args, static_args_tup)))[-4:] 2710 name = f"{prefix}_{ids}" 2711 else: 2712 for a in args: 2713 name += f"shape_{a.shape}_dtype_{a.dtype.name}_" 2714 for k, v in static_args.items(): 2715 name += f"{k}_{v}_" 2716 name = name.replace("(", "L") 2717 name = name.replace(")", "R") 2718 name = name.replace(",", "C") 2719 name = name.replace(" ", "") 2720 name = name.replace(".", "D") 2721 2722 return name 2723 2724 def get_program(self, *args, **static_args): 2725 sig = inspect.signature(self.f) 2726 if all("*" not in repr(v) for v in sig.parameters.values()): 2727 args_strs = [k for k, v in sig.parameters.items() if k != "self" and k not in self.static_argnames] 2728 static_args_strs = [k for k, v in sig.parameters.items() if k != "self" and k in self.static_argnames] 2729 2730 if args: 2731 if len(args) > len(args_strs): 2732 assert static_args_strs 2733 args, rest = args[: len(args_strs)], args[len(args_strs) :] 2734 new_static_args = {k: rest_arg for k, rest_arg in zip(static_args_strs, rest) if k not in static_args} 2735 static_args = {**new_static_args, **static_args} 2736 else: 2737 args = tuple([static_args[k] if k in static_args else arg for k, arg in zip(args_strs, args)]) 2738 2739 symvals_in = tree_map(lambda x: SymbolicTensor.like(get_symval(x)), args) 2740 static_args = tuple(static_args.items()) 2741 if self.name is None: 2742 self.name = f"jit_{str(hash((self.f, symvals_in, static_args)))[-5:]}" 2743 program, consts, out_tree = make_program(self.f, *symvals_in, static_args=static_args, name=self.name) 2744 return program, consts, out_tree 2745 2746 def __call__(self, *args, **static_args): 2747 program, consts, out_tree = self.get_program(*args, **static_args) 2748 args, in_tree = tree_flatten(args) 2749 outs = bind(backend.jit_op, *consts, *args, program=program) 2750 return tree_unflatten(out_tree, outs) 2751 2752 def lower(self, *args, **static_args): 2753 program, consts, out_tree = self.get_program(*args, **static_args) 2754 args, in_tree = tree_flatten(args) 2755 hashed_program = Hashed(program) 2756 num_consts = program.num_consts 2757 consts, args = args[:num_consts], args[num_consts:] 2758 hashed_consts = tuple(map(Hashed, consts)) 2759 jit_output = backend.jit_program(hashed_program, hashed_consts) 2760 return jit_output 2761 2762 def export(self, output_path, args, export_params=True, input_names=None, output_names=None, **kwargs): 2763 if isinstance(args, Tensor): 2764 args, static_args = (args,), dict() 2765 elif not isinstance(args[-1], dict): 2766 assert all(isinstance(a, Tensor) for a in args) 2767 static_args = dict() 2768 else: 2769 args, static_args = args[:-1], args[-1] 2770 assert isinstance(args, (tuple, list)) and isinstance(static_args, dict) 2771 jit_output = self.lower(*args, **static_args) 2772 
backend.export(jit_output, output_path, export_params, input_names, output_names, **kwargs)
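# --- Illustrative usage sketch (not part of the library, never called) ---
# jit traces the wrapped function into a Program and dispatches it through
# backend.jit_op; static_argnames marks arguments treated as trace-time
# constants. ASSUMPTION: Tensor supports multiplication by a Python scalar.
def _example_jit():
    @jit.with_options(static_argnames="scale")
    def scaled(x, scale=2):
        return x * scale                       # assumed elementwise op

    return scaled(backend.ones((3,)))          # traces, compiles, and runs the program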