slope.core

   1from pathlib import Path
   2import os
   3import json
   4from typing import (
   5    Callable,
   6    NamedTuple,
   7    Dict,
   8    Hashable,
   9    List,
  10    Any,
  11    Iterable,
  12    Sequence,
  13    Iterator,
  14    Type,
  15    Tuple,
  16    Optional,
  17    Union,
  19    Set,
  20    DefaultDict,
  21    Final,
  22)
  23import weakref
  24import types
  25from contextlib import contextmanager, ContextDecorator
  26import itertools
  28import operator as operator_py
  29import numpy as np
  30import math
  31import inspect
  32from functools import partial, lru_cache
  33import mmap
  34import traceback
  35import importlib
  36import time
  37import cProfile
  38import pstats
  39
  40# =================================
  41#   Utils
  42# =================================
  43
  44
  45class Timing(ContextDecorator):
  46    def __init__(self, prefix="", on_exit=None, enabled=True):
  47        self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
  48
  49    def __enter__(self):
  50        self.st = time.perf_counter_ns()
  51
  52    def __exit__(self, *exc):
  53        self.et = time.perf_counter_ns() - self.st
  54        if self.enabled:
  55            print(f"{self.prefix}{self.et*1e-6:6.2f} ms" + (self.on_exit(self.et) if self.on_exit else ""))
  56
  57
  58def colored(st, color: Optional[str], background=False):
  59    return (
  60        f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m"
  61        if color is not None
  62        else st
  63    )  # replace the termcolor library with one line  # noqa: E501
  64
  65
  66def _format_fcn(fcn):
  67    return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
  68
  69
  70class Profiling(ContextDecorator):
  71    def __init__(self, enabled=True, sort="cumtime", frac=0.2, fn=None, ts=1):
  72        self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3 / ts
  73
  74    def __enter__(self):
  75        self.pr = cProfile.Profile()
  76        if self.enabled:
  77            self.pr.enable()
  78
  79    def __exit__(self, *exc):
  80        if self.enabled:
  81            self.pr.disable()
  82            if self.fn:
  83                self.pr.dump_stats(self.fn)
  84            stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
  85            for fcn in stats.fcn_list[0 : int(len(stats.fcn_list) * self.frac)]:  # type: ignore[attr-defined]
  86                (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn]  # type: ignore[attr-defined]
  87                scallers = sorted(callers.items(), key=lambda x: -x[1][2])
  88                print(
  89                    f"n:{num_calls:8d}  tm:{tottime*self.time_scale:7.2f}ms  tot:{cumtime*self.time_scale:7.2f}ms",
  90                    colored(_format_fcn(fcn), "yellow") + " " * (50 - len(_format_fcn(fcn))),
  91                    colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else "",
  92                )
  93
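# Example (sketch): Timing and Profiling are context managers (and, via
# ContextDecorator, usable as decorators); the workload names below are
# illustrative.
#
#   with Timing("matmul "):
#       y = x @ w                        # prints something like "matmul   1.23 ms"
#
#   with Profiling(enabled=True, frac=0.1):
#       train_step()                     # prints the top 10% of profiled calls by cumtime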
  94
  95def dblog(*msg, enable=True):
  96    if enable:
  97        print(*msg)
  98
  99
 100def unzip2(pairs) -> Tuple[List[Any], List[Any]]:
 101    lst1, lst2 = [], []
 102    for i1, i2 in pairs:
 103        lst1 += [i1]
 104        lst2 += [i2]
 105    return lst1, lst2
 106
 107
 108def list_map(f: Callable, *xs: Iterable) -> List[Any]:
 109    return list(map(f, *xs))
 110
 111
 112def list_zip(*args: List[Any]) -> List[Any]:
 113    fst, *rest = args = list_map(list, args)
 114    n = len(fst)
 115    for arg in rest:
 116        assert len(arg) == n
 117    return list(zip(*args))
 118
 119
 120def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
 121    assert not len(lst) % 2
 122    return split_list(lst, len(lst) // 2)
 123
 124
 125def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
 126    l1, l2 = iter(l1), iter(l2)
 127    out = [next(l2) if b else next(l1) for b in which]
 128    assert next(l1, None) is next(l2, None) is None
 129    return out
 130
 131
 132def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
 133    assert 0 <= n <= len(lst)
 134    return lst[:n], lst[n:]
 135
 136
 142def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
 143    assert len(bs) == len(l)
 144    lists = lst1, lst2 = [], []
 145    for b, x in zip(bs, l):
 146        lists[b].append(x)
 147    return lst1, lst2
 148
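# Example (sketch): partition_list splits a list by a boolean mask and
# merge_lists is its inverse, so the following round-trips:
#
#   which = [False, True, False]
#   false_items, true_items = partition_list(which, ["a", "b", "c"])
#   # false_items == ["a", "c"], true_items == ["b"]
#   merge_lists(which, false_items, true_items)   # ["a", "b", "c"]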
 149
 150def lru_cache_verbose(
 151    maxsize: int = 100,
 152    typed: bool = False,
 153    tb_start: int = -12,
 154    tb_end: int = -7,
 155):
 156    def decorator(fn: Callable):
 157        @lru_cache(maxsize=maxsize, typed=typed)
 158        def wrapper(*args, **kwargs) -> Callable:
 159            return fn(*args, **kwargs)
 160
 161        def decorated_function(*args, **kwargs) -> Any:
 162            result = wrapper(*args, **kwargs)
 163            cache_info = wrapper.cache_info()
 164
 165            dblog(
 166                f"{fn.__name__}.{cache_info} {args.__hash__()}",
 167                enable=backend.LOG_LRU,
 168            )
 169            tb = "".join(traceback.format_list(traceback.extract_stack())[tb_start:tb_end]).replace("\n    ", ":\t") + "-" * 20 + "\n"
 170            dblog(f"{tb}", enable=backend.LOG_LRU)
 171
 172            return result
 173
 174        decorated_function.cache_info = wrapper.cache_info
 175        decorated_function.fn = fn
 176        return decorated_function
 177
 178    return decorator
 179
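# Example (sketch): lru_cache_verbose wraps functools.lru_cache and logs cache
# info plus a short call-site traceback whenever backend.LOG_LRU is enabled;
# the decorated function below is illustrative.
#
#   @lru_cache_verbose(maxsize=32)
#   def square(n):
#       return n * n
#
#   square(4); square(4)      # second call is served from the cache
#   square.cache_info()       # exposed on the decorated function above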
 180
 181def cuda_is_available():
 182    try:
 183        import subprocess
 184        import platform
 185
 186        cmd = f"nvidia-smi{'.exe' if platform.system() == 'Windows' else ''}"
 187        result = subprocess.run([cmd], stdout=subprocess.PIPE)
 188        output = result.stdout.decode("utf-8")
 189        return "NVIDIA-SMI" in output
 190    except FileNotFoundError:
 191        return False
 192
 193
 194class Hashed:
 195    val: Any
 196
 197    def __init__(self, val):
 198        self.val = val
 199
 200    def __hash__(self) -> int:
 201        return hash((self.val,))
 202
 203    def __eq__(self, other):
 204        if isinstance(other, Hashed):
 205            if isinstance(self.val, Tensor) and isinstance(other.val, Tensor):
 206                # Tensor.__eq__ is overloaded to mean Tensor.equal, so compare identities here
 207                return id(self.val) == id(other.val)
 208            return self.val == other.val
 209        return False
 210
 211    def __repr__(self):
 212        return f"Hashed: {repr(self.val)}"
 213
 214
 215# =================================
 216#   Tensors
 217# =================================
 218
 219
 220class DType(NamedTuple):
 221    priority: int
 222    itemsize: int
 223    name: str
 224    mlir: str
 225    numpy: type
 226
 227    @property
 228    def format_code(self):
 229        return f"slope.{self.name}"
 230
 231    def __repr__(self):
 232        return f"<DType: {self.name}>"
 233
 234
 235class dtypes:
 236    float32: Final[DType] = DType(4, 4, "float32", "f32", np.float32)
 237    uint8: Final[DType] = DType(0, 1, "uint8", "u8", np.uint8)
 238    int8: Final[DType] = DType(0, 1, "int8", "i8", np.int8)
 239    bool: Final[DType] = DType(0, 1, "bool", "i1", bool)
 240    int32: Final[DType] = DType(1, 4, "int32", "i32", np.int32)
 241    int64: Final[DType] = DType(2, 8, "int64", "i64", np.int64)
 242    uint64: Final[DType] = DType(2, 8, "uint64", "ui64", np.uint64)
 243    float16: Final[DType] = DType(0, 2, "float16", "f16", np.float16)
 244    half = float16
 245    # bfloat16: Final[DType] = DType(0, 2, "bfloat16", "bf16", np.float16)
 246
 247    all_dtypes = (bool, float16, float32, int8, int32, int64, uint8, uint64)
 248    name_dtype_map = {k.name: k for k in all_dtypes}
 249    name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
 250    mlir_dtype_map = {k.mlir: k for k in all_dtypes}
 251    mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}
 252
 253    @classmethod
 254    def is_int(cls, dtype):
 255        return dtype in (cls.uint8, cls.int8, cls.int32, cls.uint64, cls.int64)
 256
 257    @classmethod
 258    def is_float(cls, dtype):
 259        return dtype in (cls.float16, cls.float32)  # bfloat16 is commented out above
 260
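# Example (sketch): dtypes are plain DType records; the lookup tables above map
# between names, MLIR codes, and DType objects.
#
#   dtypes.name_dtype_map["float32"] is dtypes.float32   # True
#   dtypes.mlir_dtype_map["i32"] is dtypes.int32          # True
#   dtypes.is_float(dtypes.float16)                       # True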
 261
 262class Device(NamedTuple):
 263    name: str
 264    idx: int
 265
 266    @property
 267    def format_code(self):
 268        return f"'{self.name}:{self.idx}'"
 269
 270    def __repr__(self):
 271        return f"<Device: {self.format_code}>"
 272
 273
 274class devices:
 275    cpu: Final[Device] = Device("cpu", 0)
 276    metal: Final[Device] = Device("metal", 0)
 277    cuda0: Final[Device] = Device("cuda", 0)
 278    # TODO: programmatically define this class attrs to support other setup
 279    cuda = cuda0
 280    all_devices = (cpu, metal, cuda0)
 281    name_idx_device_map = {f"{k.name}:{k.idx}": k for k in all_devices}
 282    name_idx_device_map_inv = {v: k for k, v in name_idx_device_map.items()}
 283
 284
 285class TensorBuffer:
 286    def __init__(self, val):
 287        self.val = val
 288
 289
 290class Tensor:
 291    def __init__(self, val: TensorBuffer):
 292        assert isinstance(val, TensorBuffer)
 293        self.buf = val
 294
 295    @property
 296    def symval(self):
 297        return SymbolicTensor.like(self)
 298
 299    @property
 300    def default_dtype(self):
 301        return backend.default_dtype
 302
 303    def is_int(self) -> bool:
 304        return self.dtype in (
 305            dtypes.int8,
 306            dtypes.uint8,
 307            dtypes.uint64,
 308            dtypes.int32,
 309            dtypes.int64,
 310        )
 311
 312    def is_float(self) -> bool:
 313        return self.dtype in (dtypes.float16, dtypes.float32)
 314
 315    def is_unsigned(self) -> bool:
 316        return self.dtype is dtypes.uint8
 317
 318    def to_bool(self):
 319        return self.cast(dtypes.bool)
 320
 321    def short(self):
 322        return self.cast(dtypes.int8)
 323
 324    def int(self):
 325        return self.cast(dtypes.int32)
 326
 327    def long(self):
 328        return self.cast(dtypes.int64)
 329
 330    def half(self):
 331        return self.cast(dtypes.float16)
 332
 333    def float(self):
 334        return self.cast(dtypes.float32)
 335
 336    def __getattr__(self, attr):
 337        if attr in vars(backend.operator_set).keys():
 338            op = getattr(backend.operator_set, attr)
 339            return partial(op, self)
 340        elif attr in vars(backend.procedure_set).keys():
 341            procedure = getattr(backend.procedure_set, attr)
 342            assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}"
 343            return partial(procedure, self)
 344        else:
 345            return self.__getattribute__(attr)
 346
 347    def __getitem__(self, idx):
 348        return self.getitem(idx)
 349
 350    def __setitem__(self, idx, item):
 351        raise NotImplementedError
 352
 353    def str_short(self):
 354        return f"<Tensor: shape={self.shape}, dtype={self.dtype}>"
 355
 356    __neg__ = lambda self: self.neg()
 357    __add__ = lambda self, other: self.add(other)
 358    __radd__ = lambda self, other: self.add(other)
 359    __sub__ = lambda self, other: self.sub(other)
 360    __rsub__ = lambda self, other: self.sub.func(other, self)
 361    __mul__ = lambda self, other: self.mul(other)
 362    __rmul__ = lambda self, other: self.mul(other)
 363    __div__ = lambda self, other: self.div(other)
 364    __rdiv__ = lambda self, other: self.div.func(other, self)
 365    __truediv__ = __div__
 366    __rtruediv__ = __rdiv__
 367    __pow__ = lambda self, other: self.pow(other)
 368    __rpow__ = lambda self, other: self.pow.func(other, self)
 369    __matmul__ = lambda self, other: self.matmul(other)
 370    __rmatmul__ = lambda self, other: self.matmul.func(other, self)
 371    __invert__ = lambda self: self.invert()
 372    __eq__ = lambda self, other: self.equal(other)
 373    __ne__ = lambda self, other: self.not_equal(other)
 374    __ge__ = lambda self, other: self.greater_equal(other)
 375    __le__ = lambda self, other: self.less_equal(other)
 376    __gt__ = lambda self, other: self.greater(other)
 377    __lt__ = lambda self, other: self.less(other)
 378
 379    def __hash__(self):
 380        return id(self.val)
 381
 382    val = property(lambda self: self.buf.val)
 383
 384    def size(self, i):
 385        return self.shape[i]
 386
 387    @property
 388    def dtype(self):
 389        return backend.dtype_of(self)
 390
 391    @property
 392    def device(self):
 393        return backend.device_of(self)
 394
 395    def numpy(self, memmap=False):
 396        return backend.numpy_of(self, memmap)
 397
 398    @property
 399    def shape(self):
 400        return backend.shape_of(self)
 401
 402    @property
 403    def ndim(self):
 404        return len(self.shape)
 405
 406    def numel(self):
 407        return math.prod(self.shape)
 408
 409    def element_size(self):
 410        return self.dtype.itemsize
 411
 412    def nbytes(self):
 413        return self.numel() * self.element_size()
 414
 415    def __repr__(self):
 416        return f"<Tensor: val=\n{self.numpy()}\nshape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}>"
 417
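# Example (sketch): Tensor defines no compute methods itself; __getattr__ above
# forwards unknown names to the active backend's operator_set and
# procedure_set. Assuming "add" and "sum" are registered there:
#
#   y = x.add(1.0)       # resolved through backend.operator_set
#   z = x.sum(dim=0)     # resolved through the operator/procedure lookup
#   w = x + 1.0          # the dunder methods above route through the same path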
 418
 419class SymbolicTensor(Tensor):
 420    def __init__(self, shape, dtype, device):
 421        assert isinstance(dtype, DType)
 422        self._shape = tuple(int(i) for i in shape)
 423        self._dtype = dtype
 424        self._device = device
 425
 426    @property
 427    def symval(self):
 428        return self
 429
 430    @property
 431    def val(self):
 432        raise RuntimeError(f"SymbolicTensor actually has no val, from {trace_stack[-1]=}, ")
 433
 434    @property
 435    def shape(self):
 436        return self._shape
 437
 438    @property
 439    def dtype(self):
 440        return self._dtype
 441
 442    @property
 443    def device(self):
 444        return self._device
 445
 446    def like(self, **overrides):
 447        shape = overrides.get("shape", self.shape)
 448        dtype = overrides.get("dtype", self.dtype)
 449        device = overrides.get("device", self.device)
 450        return SymbolicTensor(shape, dtype, device)
 451
 452    def str_short(self):
 453        return f'{str(self.dtype)}[{",".join(str(d) for d in self.shape)}]'
 454
 455    def __hash__(self):
 456        return hash((self.shape, self.dtype))
 457
 458    def __eq__(self, other):
 459        if type(self) != type(other):
 460            return False
 461        return (self.shape == other.shape) and (self.dtype == other.dtype)
 462
 463    def __repr__(self):
 464        return f"<SymbolicTensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device}>"
 465
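# Example (sketch): a SymbolicTensor carries only shape/dtype/device metadata,
# and like() clones it with selected fields overridden.
#
#   sym = SymbolicTensor((2, 3), dtypes.float32, devices.cpu)
#   sym.like(shape=(6,))     # new SymbolicTensor with shape (6,), same dtype and device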
 466
 467# =================================
 468#   Operator
 469# =================================
 470
 471
 472class Operator:
 473    def __init__(self, name, variadic_inputs=False, nary_outputs=False):
 474        self.name = name
 475        self.variadic_inputs = variadic_inputs
 476        self.nary_outputs = nary_outputs
 477        if self.variadic_inputs:
 478            self.reorg_args = self.reorg_args_nary
 479
 480    def __hash__(self):
 481        return hash(self.name)
 482
 483    def __eq__(self, other):
 484        if not isinstance(other, Operator):
 485            return False
 486        return self.name == other.name
 487
 488    def args_fixer(self, *args, **params):
 489        return args, params
 490
 491    def __call__(self, *args, **params):
 492        args, params = self.reorg_args(args, params)
 493        args, params = self.args_fixer(*args, **params)
 494        ret = bind(self, *args, **params)
 495        if not self.nary_outputs:
 496            ret = ret[0]
 497        return ret
 498
 499    def __repr__(self) -> str:
 500        return f"<{self.name}>"
 501
 502    def typecheck(self, *args, **params):
 503        raise NotImplementedError
 504
 505    def jvp(self, *args, **params):
 506        raise NotImplementedError
 507
 508    def T(self, *args, **params):
 509        raise NotImplementedError
 510
 511    def vmap(self, *args, **params):
 512        raise NotImplementedError
 513
 514    def reorg_args(self, args, params):
 515        sig = inspect.signature(self.typecheck)
 516        args_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and k != "self"]
 517        params_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.KEYWORD_ONLY and k != "self"]
 518
 519        if args:
 520            if len(args) > len(args_strs):
 521                args, rest = args[: len(args_strs)], args[len(args_strs) :]
 522                if params_strs:
 523                    new_params = {k: rest_arg for k, rest_arg in zip(params_strs, rest) if k not in params}
 524                    params = {**new_params, **params}
 525            else:
 526                args = tuple([params[k] if k in params else arg for k, arg in zip(args_strs, args)])
 527                assert len(args) == len(args_strs)
 528        return args, params
 529
 530    def reorg_args_nary(self, args, params):
 531        return args, params
 532
 533    def partial_run(self, trace, tracers, **params):
 534        tracers_in = [trace.instantiate_const(t) for t in tracers]
 535        symvals_in = [t.symval for t in tracers_in]
 536        symvals_out = self.typecheck(*symvals_in, **params)
 537        tracers_out = [PartialRunTraceTensor(trace, make_unknown_pval(symval), None) for symval in symvals_out]
 538        instruction = InstructionDraft(
 539            self,
 540            tracers_in,
 541            params,
 542            symvals_out,
 543            list_map(weakref.ref, tracers_out),
 544        )
 545        for t in tracers_out:
 546            t.draft = instruction
 547        return tracers_out
 548
 549    def partial_run_instruction(self, unks_in, instruction):
 550        if any(unks_in):
 551            instruction1 = None
 552            instruction2 = Instruction(
 553                instruction.op,
 554                instruction.inputs,
 555                instruction.params,
 556                instruction.out_binders,
 557            )
 558            unks_out = [True for i in instruction.out_binders]
 559            res = [v for unk, v in zip(unks_in, instruction.inputs) if ((not unk) and type(v) is Var)]
 560        else:
 561            instruction1 = instruction
 562            instruction2 = None
 563            unks_out = [False for i in instruction.out_binders]
 564            res = None
 565
 566        return instruction1, instruction2, unks_out, res
 567
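# Example (sketch): calling an Operator instance runs reorg_args -> args_fixer
# -> bind, so positional and keyword arguments are matched against the
# typecheck() signature before being traced. Assuming a "cast" operator is
# registered on the backend:
#
#   y = backend.cast(x, dtype=dtypes.float16)   # equivalent to x.cast(dtypes.float16)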
 568
 569class MetaOperator(Operator):
 570    def meta_impl(self, *args, **kwargs):
 571        raise NotImplementedError
 572
 573
 574class UnaryOperator(Operator):
 575    def vmap(self, x, *, dim_size, vals_in, dims_in, **params):
 576        (x,), (x_bdim,) = vals_in, dims_in
 577        return [self(x, **params)], [x_bdim]
 578
 579    def typecheck(self, x, **params):
 580        return [SymbolicTensor.like(x)]
 581
 582    def jvp(self, primals, tangents, **params):
 583        (x,), (x_dot,) = primals, tangents
 584        return [self(x, **params)], [self(x_dot, **params)]
 585
 586
 587class BinaryOperator(Operator):
 588    boolean_output = False
 589
 590    def args_fixer(self, x, w, **params):
 591        if isinstance(x, UndefinedPrimal) or type(w) is UndefinedPrimal:
 592            assert x.shape == w.shape
 593            return (x, w), params
 594
 595        if type(x) in TraceTensor.PYTHON_TYPES:
 596            x = backend.full(shape=(), fill_value=x, dtype=w.dtype)
 597        elif type(w) in TraceTensor.PYTHON_TYPES:
 598            w = backend.full(shape=(), fill_value=w, dtype=x.dtype)
 599
 600        shape_delta = x.ndim - w.ndim
 601        if shape_delta > 0:
 602            w = w.reshape((1,) * shape_delta + w.shape)
 603        elif shape_delta < 0:
 604            x = x.reshape((1,) * -shape_delta + x.shape)
 605
 606        shape_ret = tuple([max(x, w) for x, w in zip(x.shape, w.shape)])
 607        if x.shape != shape_ret:
 608            x = x.expand(shape_ret)
 609        if w.shape != shape_ret:
 610            w = w.expand(shape_ret)
 611
 612        if type(x) is Tensor and isinstance(w, TraceTensor):
 613            x = w._trace.pure(x)
 614        elif type(w) is Tensor and isinstance(x, TraceTensor):
 615            w = x._trace.pure(w)
 616        # TODO: https://jax.readthedocs.io/en/latest/type_promotion.html
 617        if x.dtype != w.dtype:
 618            # {int, bool} -> float
 619            if dtypes.is_float(x.dtype) ^ dtypes.is_float(w.dtype):
 620                if dtypes.is_float(w.dtype):
 621                    x = x.cast(w.dtype)
 622                elif dtypes.is_float(x.dtype):
 623                    w = w.cast(x.dtype)
 624            # bool -> int
 625            elif dtypes.is_int(x.dtype) ^ dtypes.is_int(w.dtype):
 626                if dtypes.is_int(w.dtype):
 627                    x = x.cast(w.dtype)
 628                elif dtypes.is_int(x.dtype):
 629                    w = w.cast(x.dtype)
 630            else:  # TODO: fine-grained type promotions
 631                raise NotImplementedError("No other type promotion rules")
 632
 633        return (x, w), params
 634
 635    def vmap(self, dim_size, vals_in, dims_in, **params):
 636        (x, w), (x_bdim, w_bdim) = vals_in, dims_in
 637        if x_bdim != w_bdim:
 638            if x_bdim is None:
 639                x = VMapTrace.move_vmap_dim(x, dim_size, x_bdim, w_bdim)
 640                x_bdim = w_bdim
 641            else:
 642                w = VMapTrace.move_vmap_dim(w, dim_size, w_bdim, x_bdim)
 643        return [self(x, w, **params)], [x_bdim]
 644
 645    def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]:
 646        if not isinstance(x, (Tensor, SymbolicTensor)) or not isinstance(y, (Tensor, SymbolicTensor)):
 647            raise TypeError
 648        symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype)
 649        symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype)
 650        if x.dtype != y.dtype:
 651            raise TypeError(f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})")
 652        if symx == symy:
 653            return [symx]
 654        shape_delta = len(symx.shape) - len(symy.shape)
 655        if shape_delta > 0:
 656            symy = symy.like(shape=(1,) * shape_delta + symy.shape)
 657        elif shape_delta < 0:
 658            symx = symx.like(shape=(1,) * -shape_delta + symx.shape)
 659        if symx == symy:
 660            return [symx]
 661        else:
 662            shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)])
 663            if symx.shape != shape_ret:
 664                symx = symx.like(shape=shape_ret)
 665            if symy.shape != shape_ret:
 666                symy = symy.like(shape=shape_ret)
 667            if symx != symy:
 668                raise TypeError(f"symx ({symx}) != symy ({symy})")
 669            return [symx]
 670
 671    def jvp(self, primals, tangents, **params):
 672        (x, w), (x_dot, w_dot) = primals, tangents
 673        return [self(x, w, **params)], [self(x_dot, w_dot, **params)]
 674
 675    def T(self, cotangents, x, w):
 676        (gL_y,) = cotangents
 677        if self.boolean_output:
 678            gL_y = gL_y.cast(x.dtype)
 679        if isinstance(x, UndefinedPrimal):
 680            return [gL_y, NullCotangent]
 681        elif isinstance(w, UndefinedPrimal):
 682            return [NullCotangent, gL_y]
 683        else:
 684            raise ValueError
 685
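# Example (sketch): args_fixer above gives binary operators NumPy-style
# broadcasting: Python scalars are lifted to 0-d tensors with the other
# operand's dtype, lower-rank operands get leading 1s, and both sides are
# expanded to the common shape (with a small int/bool -> float promotion rule).
#
#   # assuming x has shape (4, 3) and y has shape (3,):
#   x + 1.0     # 1.0 lifted to x.dtype, broadcast to (4, 3)
#   x * y       # y reshaped to (1, 3), then expanded to (4, 3)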
 686
 687class ReduceOperator(Operator):
 688    def args_fixer(self, x, *, dim=None, keepdim=False):
 689        if dim is None:
 690            dim = tuple(range(x.ndim))
 691        elif isinstance(dim, int):
 692            dim = (dim,)
 693        dim = tuple(a if a >= 0 else a + len(x.shape) for a in dim)
 694        return (x,), dict(dim=dim, keepdim=keepdim)
 695
 696    def vmap(self, dim_size, vals_in, dims_in, *, dim, keepdim):
 697        (x,), (x_bdim,) = vals_in, dims_in
 698        dim = tuple(a + (x_bdim <= a) for a in dim)
 699        out_bdim = x_bdim - sum(a < x_bdim for a in dim)
 700        return [self(x, dim=dim, keepdim=keepdim)], [out_bdim]
 701
 702    def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]:
 703        dim = list(set([a + len(x.shape) if a < 0 else a for a in dim]))
 704        if keepdim:
 705            new_shape = [d if i not in dim else 1 for i, d in enumerate(x.shape)]
 706        else:
 707            new_shape = [d for i, d in enumerate(x.shape) if i not in dim]
 708        return [SymbolicTensor.like(x, shape=tuple(new_shape))]
 709
 710
 711class InitOperator(Operator):
 712    def vmap(self, dim_size, vals_in, dims_in, **params):
 713        (x_bdim,) = dims_in
 714        y = self(**params)
 715        y = y.unsqueeze(x_bdim)
 716        return [y], [x_bdim]
 717
 718    def jvp(self, primals, tangents, **params):
 719        y = self(**params)
 720        y_dot = NullCotangent(y.symval)
 721        return [y], [y_dot]
 722
 723    def T(self, cotangents, **params):
 724        return [NullCotangent(cotangents[0])]
 725
 726
 727class ShapeOperator(Operator):
 728    pass
 729
 730
 731class GeneralReduceOperator(Operator):
 732    pass
 733
 734
 735class OperatorSet:
 736    def __init__(self):
 737        self.register("jit_op")(JitOp)
 738
 739    def register(self, name, variadic_inputs=False, nary_outputs=False, aliases=()):
 740        def wrap(op_cls):
 741            assert name not in vars(self)
 742            op = op_cls(name, variadic_inputs, nary_outputs)
 743            setattr(self, name, op)
 744            for a in aliases:
 745                setattr(self, a, op)
 746            return op_cls
 747
 748        return wrap
 749
 750
 751class ProcedureSet:
 752    def register(self, aliases=()):
 753        def wrap(f):
 754            assert f.__name__ not in vars(self)
 755            setattr(self, f.__name__, f)
 756            for a in aliases:
 757                setattr(self, a, f)
 758            return f
 759
 760        return wrap
 761
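# Example (sketch): backends are assembled from an OperatorSet and a
# ProcedureSet; the op and procedure below are illustrative, not part of the
# actual sets.
#
#   operator_set = OperatorSet()
#
#   @operator_set.register("relu_op")
#   class ReluOp(UnaryOperator):
#       pass
#
#   procedure_set = ProcedureSet()
#
#   @procedure_set.register(aliases=("rectify",))
#   def relu(x):
#       return x.maximum(0.0)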
 762
 763class CodegenOutput(NamedTuple):
 764    code_lines: List[str]
 765    fn_defs: Dict[str, List[str]]
 766    in_binders: List["ProgramEnvVar"]
 767    outs: List["ProgramEnvVar"]
 768
 769
 770class Backend:
 771    LOG_LRU = int(os.environ.get("LOG_LRU", 0))
 772    LOG_JIT = int(os.environ.get("LOG_JIT", 0))
 773    LOG_TREE = int(os.environ.get("LOG_TREE", 0))
 774    LOG_BACKEND = int(os.environ.get("LOG_BACKEND", 0))
 775    LOG_PROGRAM = int(os.environ.get("LOG_PROGRAM", 0))
 776    LOG_INIT = int(os.environ.get("LOG_INIT", 1))
 777    device_var = os.environ.get("DEFAULT_DEVICE", "cpu:0")
 778    if ":" not in device_var:
 779        device_var += ":0"
 780    DEFAULT_DEVICE = devices.name_idx_device_map[device_var]
 781    DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")]
 782    dtype_for_indices: DType = None  # need to override
 783
 784    def __init__(
 785        self,
 786        operator_set: OperatorSet,
 787        procedure_set: ProcedureSet,
 788    ):
 789        self.operator_set = operator_set
 790        self.procedure_set = procedure_set
 791        self.node_types = dict()
 792        self.impls = dict()
 793        self.register_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs), "tuple")
 794        self.register_node(list, lambda l: (None, l), lambda _, xs: list(xs), "list")
 795        self.register_node(
 796            dict,
 797            lambda d: list_map(tuple, unzip2(sorted(d.items()))),
 798            lambda keys, vals: dict(list_zip(keys, vals)),
 799            "dict",
 800        )
 801        self.register_node(
 802            UndefinedPrimal,
 803            lambda u: (u.symval, ()),
 804            lambda symval, _: UndefinedPrimal(symval),
 805            "UndefinedPrimal",
 806        )
 807
 808    def set_impl(self, op: Union[types.LambdaType, types.FunctionType]):
 809        def set_impl_(fn):
 810            self.impls[op] = types.MethodType(fn, self)
 811
 812        return set_impl_
 813
 814    def register_node(self, ty: Type, to_iter: Callable, from_iter: Callable, name=None) -> None:
 815        if name is None:
 816            name = str(ty)
 817        self.node_types[ty] = NodeType(name, to_iter, from_iter)
 818
 819    def __getattr__(self, attr):
 820        try:
 821            dblog(
 822                f"Looking {self}.{attr} in operator_set",
 823                enable=backend.LOG_BACKEND,
 824            )
 825            return getattr(self.operator_set, attr)
 826        except AttributeError:
 827            pass
 828        try:
 829            dblog(
 830                f"Looking {self}.{attr} in procedure_set",
 831                enable=backend.LOG_BACKEND,
 832            )
 833            return getattr(self.procedure_set, attr)
 834        except AttributeError:
 835            pass
 836        dblog(
 837            f"Fallback to default {self} getattribute",
 838            enable=backend.LOG_BACKEND,
 839        )
 840        return super().__getattribute__(attr)
 841
 842    def tensor(
 843        self,
 844        val: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
 845        dtype: Optional[Any] = None,
 846        device=None,
 847    ):
 848        if isinstance(val, TensorBuffer):
 849            return Tensor(val)
 850        elif isinstance(val, Tensor):
 851            return val
 852        if type(val) is bytes:
 853            val = np.frombuffer(val, dtype=dtype)
 854        return self.from_numpy(val, dtype, device)
 855
 856    def symbolic_tensor(
 857        self,
 858        shape: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
 859        dtype: Optional[Any] = None,
 860        device=None,
 861    ):
 862        dtype = dtype or self.DEFAULT_DTYPE
 863        device = device or self.DEFAULT_DEVICE
 864        return SymbolicTensor(shape, dtype, device)
 865
 866    def seed(self, seed):
 867        raise NotImplementedError
 868
 869    @property
 870    def default_dtype_value(self):
 871        return self.dtype_map[backend.DEFAULT_DTYPE]
 872
 873    def set_method(self, method):
 874        setattr(self, method.__name__, types.MethodType(method, self))
 875
 876    def from_numpy(self, val, dtype, device):
 877        raise NotImplementedError
 878
 879    def numpy_of(self, tensor):
 880        raise NotImplementedError
 881
 882    def device_of(self, tensor):
 883        raise NotImplementedError
 884
 885    def shape_of(self, tensor):
 886        raise NotImplementedError
 887
 888    def dtype_of(self, tensor):
 889        raise NotImplementedError
 890
 891    @lru_cache_verbose()
 892    def jit_program(
 893        self,
 894        hashed_program: Hashed,
 895        hashed_consts: Tuple[Hashed, ...],
 896    ):
 897        program: Program = hashed_program.val
 898        typecheck_program(program)
 899        consts = [x.val for x in hashed_consts]
 900        in_symvals = [v.symval for v in program.in_binders[len(consts) :]]
 901        codegen_output: CodegenOutput = self.codegen(program, consts + in_symvals, fn_name="main")
 902        fn, code = self.compile(codegen_output)
 903        jit_output = JitOutput(program, codegen_output, fn, code, consts)
 904        return jit_output
 905
 906    def codegen(self, program: "Program", args: Tuple, *, fn_name: str) -> CodegenOutput:
 907        "Returns compiler IR from the Program"
 908        raise NotImplementedError
 909
 910    def compile(self, codegen_output: CodegenOutput):
 911        "Compiles compiler IR to a Python callable function"
 912        raise NotImplementedError
 913
 914    def export(self, jit_output, *args, **params):
 915        raise NotImplementedError
 916
 917    def load(self, path, single_key="_tensor"):
 918        with open(path, mode="rb") as f:
 919            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as m:
 920                json_len = int(np.frombuffer(m[:8], dtype=np.int64)[0])
 921                start = 8 + json_len
 922                metadata = json.loads(m[8:start])
 923                ret = {}
 924                for k, v in metadata.items():
 925                    if k != "__metadata__":
 926                        dtype = dtypes.mlir_dtype_map[v["dtype"]]
 927                        data_start = start + v["data_offsets"][0]
 928                        data_end = start + v["data_offsets"][1]
 929                        t_np = np.frombuffer(m[data_start:data_end], dtype=dtype.numpy)
 930                        t = backend.tensor(t_np, dtype=dtype)
 931                        t = t.reshape(tuple(v["shape"]))
 932                        ret[k] = t
 933                if len(ret) == 1 and single_key in ret.keys():
 934                    return ret[single_key]
 935                return ret
 936
 937    def save(self, tensors: Dict[str, Tensor], path: str, single_key="_tensor"):
 938        if isinstance(tensors, Tensor):
 939            tensors = {single_key: tensors}
 940        else:
 941            assert all((isinstance(k, str) and isinstance(v, Tensor)) for k, v in tensors.items())
 942
 943        metadata, offset = {}, 0
 944        for k, v in tensors.items():
 945            metadata[k] = {
 946                "dtype": v.dtype.mlir,
 947                "shape": list(v.shape),
 948                "data_offsets": [offset, offset + v.nbytes()],
 949            }
 950            offset += v.nbytes()
 951        j = json.dumps(metadata, separators=(",", ":"))
 952        Path(path).unlink(missing_ok=True)
 953        jbytes = j.encode("utf-8")
 954        start = 8 + len(jbytes)
 955        with open(path, mode="wb") as f:  # make empty file, fill with enough space
 956            f.write(b"\x00" * (start + offset))
 957        with open(path, mode="r+b") as f:
 958            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_WRITE) as m:
 959                m[0:8] = np.int64(len(j)).tobytes()
 960                m[8:start] = jbytes
 961                for t, tm in zip(tensors.values(), metadata.values()):
 962                    data_start, data_end = tm["data_offsets"]
 963                    m[start + data_start : start + data_end] = t.numpy().tobytes()
 964
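# Example (sketch): save()/load() use a safetensors-like layout: 8 bytes for
# the JSON header length, then the JSON metadata (dtype, shape, data_offsets
# per tensor), then the raw tensor bytes. The paths below are illustrative.
#
#   backend.save({"w": w, "b": b}, "/tmp/params.safetensors")
#   params = backend.load("/tmp/params.safetensors")    # dict of Tensors
#   backend.save(w, "/tmp/w.safetensors")                # single tensor stored under "_tensor"
#   w2 = backend.load("/tmp/w.safetensors")              # ...and unwrapped again on load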
 965
 966# =================================
 967#   Program
 968# =================================
 969
 970
 971class Var:
 972    def __init__(self, symval):
 973        self.symval = symval
 974        self.val = None
 975
 976
 977class Lit:
 978    def __init__(self, val):
 979        self.symval = SymbolicTensor.like(get_symval(val))
 980        self.val = val
 981
 982
 983Atom = Union[Var, Lit]
 984
 985
 986class Instruction(NamedTuple):
 987    op: Operator
 988    inputs: List[Atom]
 989    params: Dict[str, Any]
 990    out_binders: List[Atom]
 991
 992
 993class ProgramEnvVar(NamedTuple):
 994    name: str
 995    symval: SymbolicTensor
 996    is_const: bool = False
 997
 998    @property
 999    def shape(self):
1000        return self.symval.shape
1001
1002    @property
1003    def dtype(self):
1004        return self.symval.dtype
1005
1006    @property
1007    def device(self):
1008        return self.symval.device
1009
1010    @property
1011    def ndim(self):
1012        return self.symval.ndim
1013
1014    def numpy(self):
1015        return self.symval.numpy()
1016
1017    def __repr__(self):
1018        return f"<ProgramEnvVar: name={self.name}, symval={self.symval}>"
1019
1020    str_short = __repr__
1021
1022
1023class Program:
1024    def __init__(
1025        self,
1026        in_binders: Any,
1027        instructions: Tuple[Instruction],
1028        outs: Any,
1029        num_consts: int = 0,
1030        static_args: Any = (),
1031        name: str = "my_program",
1032        indent_amount=4,
1033    ):
1034        self.in_binders: Any = in_binders
1035        self.outs: Any = outs
1036        self.instructions = self.prune_instructions(instructions, outs)
1037        self.num_consts: int = num_consts
1038        self.static_args = static_args
1039        self.name: str = name
1040        self.indent_amount: int = indent_amount
1041
1042        self.env: Dict[ProgramEnvVar, Any] = dict()
1043        for inb in self.in_binders:
1044            prefix = "x" if type(inb.symval) is SymbolicTensor else "c"
1045            idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()])
1046            self.env[inb] = ProgramEnvVar(f"{prefix}{idx}", inb.symval, True if prefix == "c" else False)
1047        for instruction in self.instructions:
1048            if len(instruction.out_binders) == 0:
1049                continue
1050            for outb in instruction.out_binders:
1051                prefix = "y" if outb in self.outs else "z"
1052                idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()])
1053                self.env[outb] = ProgramEnvVar(f"{prefix}{idx}", outb.symval)
1054        self.curr_repr = repr(self)
1055
1056    def pprint_shape(self, symval, scalar_as_empty_array=False):
1057        xdtype = symval.dtype.mlir
1058        if len(symval.shape) > 0:
1059            xshape = f"{', '.join((repr(i) for i in symval.shape))}"
1060            return f"[{xshape}, {xdtype}]"
1061        else:
1062            return f"[{xdtype}]"
1063
1064    def pprint_sig(self, in_symvals, out_symvals, unpack_unary_output=False):
1065        in_code = ", ".join(self.pprint_shape(t) for t in in_symvals)
1066        in_code = f"({in_code})" if len(in_symvals) > 1 else in_code
1067        out_code = ", ".join(self.pprint_shape(t) for t in out_symvals)
1068        out_code = f"({out_code})" if len(out_symvals) > 1 or unpack_unary_output else out_code
1069        typing_code = f"{in_code} -> {out_code}"
1070        return typing_code
1071
1072    def __repr__(self):
1073        fn_defs = self.instructions_as_code(self, dict())
1074        return "\n".join(line for code_lines in fn_defs.values() for line in code_lines)
1075
1076    def save(self, *args, dir_path="/tmp/slope_program", dry_run=False):
1077        os.makedirs(dir_path, exist_ok=True)
1078        head_code_lines = [f"import slope # backend={backend.__class__.__name__}"]
1079        fn_defs = self.instructions_as_code(self, dict())
1080        in_binders_vars = [self.env[i] for i in self.in_binders]
1081        for i in range(len(self.in_binders)):
1082            ibv = in_binders_vars[i]
1083            if ibv.is_const:
1084                const_filename = f"{ibv.name}.safetensors"
1085                const_path = os.path.join(dir_path, f"{const_filename}")
1086                if not dry_run:
1087                    backend.save(args[i], const_path)
1088                dblog(
1089                    f"Saved {ibv.name} at {const_path}",
1090                    enable=backend.LOG_BACKEND,
1091                )
1092                head_code_lines += [f"""{ibv.name} = slope.load("./{const_filename}")"""]
1093        head_code_lines += [""]
1094        code = "\n".join(head_code_lines + [line for code_lines in fn_defs.values() for line in code_lines])
1095        dblog(
1096            f"Contents of {self.name}:\n```\n{code}\n```",
1097            enable=backend.LOG_BACKEND,
1098        )
1099        program_path = os.path.join(dir_path, "main.py")
1100        if not dry_run:
1101            with open(program_path, "w") as f:
1102                f.write(code)
1103        dblog(
1104            f"Saved program {self.name} at {program_path}",
1105            enable=backend.LOG_BACKEND,
1106        )
1107        ls_contents = "\n\t".join(os.listdir(dir_path))
1108        dblog(
1109            f"Contents of {dir_path}:\n\t{ls_contents}",
1110            enable=backend.LOG_BACKEND,
1111        )
1112
1113    def __hash__(self):
1114        return hash(self.curr_repr)
1115
1116    def __eq__(self, other):
1117        return self is other
1118
1119    @classmethod
1120    def instructions_as_code(cls, program, fn_defs):
1121        def indent(code, indent_amount):
1122            spaces = " " * (len(code) - len(code.lstrip()))
1123            spaces += " " * indent_amount
1124            return "\n".join([spaces + line for line in code.strip().split("\n")])
1125
1126        in_binders_vars = [program.env[i] for i in program.in_binders]
1127        body_code_lines = []
1128        for instruction in program.instructions:
1129            if len(instruction.out_binders) == 0:
1130                continue
1131            params = instruction.params.copy()
1132            for param_name, param in params.items():
1133                if isinstance(param, Program):
1134                    sub_program = param
1135                    fn_defs = cls.instructions_as_code(sub_program, fn_defs)
1136                    program_in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs)
1137                    params[param_name] = f"slope.make_program({sub_program.name}, {program_in_vals})[0]"
1138                if isinstance(param, DType):
1139                    params[param_name] = f"slope.{param.name}"
1140            param_vals = ", ".join(f"{param_name}={param}" for param_name, param in params.items())
1141            in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs)
1142            out_vals = ", ".join(f"{program.env[z].name}" for z in instruction.out_binders)
1143            sig = program.pprint_sig(
1144                [program.env[x].symval for x in instruction.inputs],
1145                [program.env[y].symval for y in instruction.out_binders],
1146            )
1147            line = f"""{out_vals} = slope.{instruction.op.name}({in_vals}{", " if (param_vals and in_vals) else ""}{param_vals}) # {sig}"""
1148            body_code_lines += [indent(line, program.indent_amount)]
1149
1150        fn_args_str = ", ".join([f"{i.name}" for i in in_binders_vars])
1151        # fn_static_args_str = ", ".join([f"{a}={a_val}" for a, a_val in program.static_args])
1152        out_vars = [program.env[o] for o in program.outs]
1153        fn_sig = program.pprint_sig(
1154            [i.symval for i in in_binders_vars],
1155            [o.symval for o in out_vars],
1156        )
1157        head_code_line = [f"def {program.name}({fn_args_str}): # {fn_sig}"]
1158        out_str = ", ".join([f"{o.name}" for o in out_vars])
1159        tail_code_line = [indent(f"return {out_str}", program.indent_amount)]
1160        code_lines = head_code_line + body_code_lines + tail_code_line + ["\n"]
1161
1162        fn_defs[program.name] = code_lines
1163        return fn_defs
1164
1165    @staticmethod
1166    def prune_instructions(instructions, outs):
1167        graph = dict()
1168        for instruction in instructions:
1169            parent_nodes, child_nodes = instruction.out_binders, instruction.inputs
1170            for parent in parent_nodes:
1171                if parent not in graph:
1172                    graph[parent] = set()
1173                for child in child_nodes:
1174                    graph[parent].add(child)
1175        visited_from_terminal = set()
1176
1177        def dfs(node, visited):
1178            visited.add(node)
1179            if node in graph:
1180                for neighbor in graph[node]:
1181                    if neighbor not in visited:
1182                        dfs(neighbor, visited)
1183
1184        for terminal_node in outs:
1185            dfs(terminal_node, visited_from_terminal)
1186        unreachable_nodes = set(graph.keys()) - visited_from_terminal
1187
1188        instructions_to_prune = []
1189        for instruction in instructions:
1190            parent_nodes, child_nodes = instruction.out_binders, instruction.inputs
1191            if any(node in unreachable_nodes for node in parent_nodes) or any(node in unreachable_nodes for node in child_nodes):
1192                instructions_to_prune += [instruction]
1193        new_instructions = [inst for inst in instructions if inst not in instructions_to_prune]
1194        if backend.LOG_PROGRAM:
1195            LI = len(instructions)
1196            LNI = len(new_instructions)
1197            DIFF = LI - LNI
1198            UN = len(unreachable_nodes)
1199            dblog(f"Before: {LI}\tAfter: {LNI}\tDiff vs Unreachables: {DIFF} == {UN} = {DIFF==UN}")
1200        return new_instructions
1201
1202
1203class ProgramType(NamedTuple):
1204    in_types: Tuple[SymbolicTensor]
1205    out_types: Tuple[SymbolicTensor]
1206
1207    def __repr__(self):
1208        in_types = ", ".join(symval.str_short() for symval in self.in_types)
1209        out_types = ", ".join(symval.str_short() for symval in self.out_types)
1210        return f"({in_types}) -> ({out_types})"
1211
1212
1213# =================================
1214#   Tracer and Trace
1215# =================================
1216
1217
1218class Empty:
1219    pass
1220
1221
1222empty = Empty()
1223
1224
1225class Store:
1226    val = empty
1227
1228    def set_value(self, val):
1229        assert self.val is empty
1230        self.val = val
1231
1232    def __call__(self):
1233        return self.val
1234
1235
1236class NodeType(NamedTuple):
1237    name: str
1238    flatten: Callable
1239    unflatten: Callable
1240
1241
1242class TreeDef(NamedTuple):
1243    node_type: NodeType
1244    node_metadata: Hashable
1245    child_treedefs: Tuple["TreeDef", ...]
1246
1247    def __repr__(self):
1248        ret = self.tree_repr(self)
1249        return ret
1250
1251    def tree_repr(self, tree, indent="  ", prefix="", last=True):
1252        ret = ""
1253
1254        def _tree_repr(tree, indent, prefix, last):
1255            nonlocal ret
1256            if isinstance(tree, TreeDef):
1257                ret += f'{prefix} {("└─" if last else "├─")} {tree.node_type.name}\n'
1258                for i, item in enumerate(tree.child_treedefs):
1259                    new_prefix = prefix + (indent if not last else "   ")
1260                    new_last = i == len(tree.child_treedefs) - 1
1261                    _tree_repr(item, indent, new_prefix, new_last)
1262            else:
1263                ret += f'{prefix} {("└─" if last else "├─")} {tree}\n'
1264
1265        _tree_repr(tree, indent="  ", prefix="", last=True)
1266        return ret
1267
1268    @property
1269    def num_leaves(self):
1270        def get_num_leaves(x):
1271            if isinstance(x, Leaf):
1272                return 1
1273            else:
1274                return sum(get_num_leaves(sub_x) for sub_x in x.child_treedefs)
1275
1276        return sum(get_num_leaves(x) for x in self.child_treedefs)
1277
1278
1279class Leaf:
1280    def __init__(self, val):
1281        if hasattr(val, "shape"):
1282            val = SymbolicTensor.like(val)
1283        self.val = val
1284
1285    def __repr__(self):
1286        ret = self.val.str_short() if isinstance(self.val, SymbolicTensor) else repr(self.val)
1287        return f"<Leaf: {ret}>"
1288
1289    def __hash__(self):
1290        return hash(self.val)
1291
1292    def __eq__(self, other):
1293        return True  # so that TreeDef.__eq__ ignores Leaf contents
1294        # if isinstance(other, Leaf): # TODO: test above assumption
1295        #     return self.val == other.val
1296
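# Example (sketch): containers participate in tree flattening through
# Backend.register_node(type, flatten, unflatten); tuple/list/dict and
# UndefinedPrimal are wired up in Backend.__init__, and a custom container
# (MyPair below is hypothetical) can follow the same (metadata, children)
# convention.
#
#   backend.register_node(
#       MyPair,
#       lambda p: (None, (p.first, p.second)),   # flatten -> (metadata, children)
#       lambda _, xs: MyPair(*xs),               # unflatten(metadata, children)
#       "MyPair",
#   )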
1297
1298# =================================
1299#   jit operator
1300# =================================
1301
1302
1303class JitOutput:
1304    def __init__(self, program: Program, codegen_output: CodegenOutput, fn, code: str, consts: List[Any]):
1305        super().__init__()
1306        self.program = program
1307        self.code = code
1308        self.codegen_output = codegen_output
1309        self.fn: Callable = fn
1310        self.consts = consts
1311
1312    def __call__(self, *args, **params):
1313        args, in_tree = tree_flatten(args)
1314        args = tree_map(lambda a: a.val if isinstance(a, Tensor) else a, args)
1315        try:
1316            outs = self.fn(*args, **params)
1317            if not isinstance(outs, tuple):  # TODO: IREE FunctionInvoker destructures 1-tuples; undo that here
1318                outs = (outs,)
1319        except Exception as e:
1320            dblog(self.code, enable=backend.LOG_JIT)
1321            raise
1322        return [backend.tensor(TensorBuffer(o)) for o in outs]
1323
1324
1325class JitOp(MetaOperator):
1326    def meta_impl(self, *args, program: Program, **_):
1327        hashed_program = Hashed(program)
1328        num_consts = program.num_consts
1329        consts, args = args[:num_consts], args[num_consts:]
1330        hashed_consts = tuple(map(Hashed, consts))
1331        jit_output = backend.jit_program(hashed_program, hashed_consts)
1332        ret = jit_output(*consts, *args)
1333        return ret
1334
1335    def reorg_args(self, args, params):
1336        return args, params
1337
1338    def typecheck(self, *in_types, program: Program):
1339        program_type = typecheck_program(program)
1340        if not all(t1 == t2 for t1, t2 in zip(program_type.in_types, in_types)):
1341            ret = "Type mismatch program.in_types vs in_types:\n"
1342            for i, j in zip(program_type.in_types, in_types):
1343                ret += f"{i}, {j}, {i == j}"
1344            raise TypeError(ret)
1345        return program_type.out_types
1346
1347    def vmap(self, dim_size, vals_in, dims_in, program: Program):
1348        program, consts = vmap_program(program, dim_size, tuple(dims_in))
1349        outs = self(*consts, *vals_in, program=program)
1350        if not isinstance(outs, tuple):
1351            outs = (outs,)
1352        return outs, [0] * len(outs)
1353
1354    def jvp(self, primals, tangents, *, program):
1355        new_program, new_consts = jvp_program(program)
1356        outs = bind(
1357            self,
1358            *new_consts,
1359            *primals,
1360            *tangents,
1361            program=new_program,
1362        )
1363        n = len(outs) // 2
1364        primals_out, tangents_out = outs[:n], outs[n:]
1365        return primals_out, tangents_out
1366
1367    def T(self, cotangents, *invals, program):
1368        undef_primals = [isinstance(x, UndefinedPrimal) for x in invals]
1369        transposed_program, new_consts = transpose_program(program, tuple(undef_primals))
1370
1371        residuals, _ = partition_list(undef_primals, invals)
1372        outs = bind(
1373            self,
1374            *new_consts,
1375            *residuals,
1376            *cotangents,
1377            program=transposed_program,
1378        )
1379        outs = iter(outs)
1380
1381        return [next(outs) if undef else None for undef in undef_primals]
1382
1383    def partial_run(self, trace, tracers, *, program):
1384        in_unknowns = [not t.pval.is_known for t in tracers]
1385        program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns)
1386        known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
1387        known_vals = [t.pval.const for t in known_tracers]
1388        outs1_res = bind(backend.jit_op, *known_vals, program=program1)
1389        outs1, res = split_list(outs1_res, len(program1.outs) - num_res)
1390        res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
1391        outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs]
1392        instruction = InstructionDraft(
1393            self,
1394            res_tracers + unknown_tracers,
1395            dict(program=program2),
1396            [v.symval for v in program2.outs],
1397            list_map(weakref.ref, outs2),
1398        )
1399        for t in outs2:
1400            t.draft = instruction
1401
1402        return merge_lists(out_unknowns, outs1, outs2)
1403
1404    def partial_run_instruction(self, unks_in, instruction) -> Tuple[Instruction, Instruction, List[bool], List[Var]]:
1405        program = instruction.params["program"]
1406        program1, program2, out_unknowns, num_res = partial_run_program(program, unks_in)
1407        ins1, ins2 = partition_list(unks_in, instruction.inputs)
1408        out_binders1, out_binders2 = partition_list(out_unknowns, instruction.out_binders)
1409        res = [Var(v.symval) for v in program2.in_binders[:num_res]]
1410        instruction1 = Instruction(self, ins1, dict(program=program1), out_binders1 + res)
1411        instruction2 = Instruction(self, res + ins2, dict(program=program2), out_binders2)
1412        return instruction1, instruction2, out_unknowns, res
1413
1414
1415class MainTrace(NamedTuple):
1416    level: int
1417    trace_type: Type["Trace"]
1418    global_data: Optional[Any]
1419
1420
1421class Trace:
1422    main: MainTrace
1423
1424    def __init__(self, main: MainTrace) -> None:
1425        self.main = main
1426
1427    def pure(self, val):
1428        raise NotImplementedError
1429
1430    def run_op(self, op, tracers, params):
1431        raise NotImplementedError
1432
1433
1434class RunTrace(Trace):
1435    pure = lambda self, x: x
1436
1437    def run_op(self, op: Operator, args, params):
1438        if isinstance(op, MetaOperator):
1439            args, params = op.reorg_args(args, params)
1440            args, params = op.args_fixer(*args, **params)
1441            ret = op.meta_impl(*args, **params)
1442        else:
1443            fn = self.get_fn(op, *tuple(SymbolicTensor.like(a) for a in args), **params)
1444            # with Timing(f"RUN {op}"):ret = jit(
1445            ret = jit(
1446                fn,
1447                static_argnames=("params",),
1448                name=jit.get_jit_name(args, params, op.name),
1449            )(*args, **params)
1450
1451        return ret
1452
1453    @staticmethod
1454    @lru_cache_verbose()
1455    def get_fn(op, *symval_args, **params):
1456        def fn(*args, **params):
1457            return [op(*args, **params)]
1458
1459        return fn
1460
1461
1462class SymbolicRunTrace(Trace):
1463    # pure = lambda self, x: x
1464    def pure(self, val: Any) -> SymbolicTensor:
1465        return val.symval
1466
1467    def run_op(self, op, tracers, params):
1468        symvals_in = tree_map(lambda x: x.symval, tracers)
1469        symvals_out = op.typecheck(*symvals_in, **params)
1470        return symvals_out
1471
1472
1473class TraceTensor(Tensor):
1474    PYTHON_TYPES = {
1475        bool,
1476        int,
1477        float,
1478    }
1479    _trace: "Trace"
1480
1481    def __init__(self, *args, **kwargs):
1482        raise NotImplementedError
1483
1484    symval = property(lambda self: get_symval(self.val))
1485    dtype = property(lambda self: self.symval.dtype)
1486    shape = property(lambda self: self.symval.shape)
1487    device = property(lambda self: self.symval.device)
1488
1489    @property
1490    def val(self):
1491        raise NotImplementedError
1492
1493    def __str__(self):
1494        return repr(self)
1495
1496    def full_lower(self):
1497        return self
1498
1499    @property
1500    def ndim(self):
1501        return len(self.shape)
1502
1503    def __repr__(self):
1504        return f"{self.__class__.__name__}({repr(self.symval)})"
1505
1506
1507class VMapTraceTensor(TraceTensor):
1508    def __init__(self, trace, val, vmap_dim):
1509        self._trace = trace
1510        self._val = val
1511        self.vmap_dim = vmap_dim
1512
1513    @property
1514    def val(self):
1515        return self._val
1516
1517    @property
1518    def symval(self):
1519        symval = get_symval(self.val)
1520        if self.vmap_dim is None:
1521            return symval
1522        else:
1523            shape = list(symval.shape)
1524            del shape[self.vmap_dim]
1525            return symval.like(shape=tuple(shape))
1526
1527    def full_lower(self):
1528        if self.vmap_dim is None:
1529            return full_lower(self.val)
1530        else:
1531            return self
1532
1533
1534class VMapTrace(Trace):
1535    pure = lambda self, val: VMapTraceTensor(self, val, None)
1536
1537    @property
1538    def dim_size(self):
1539        return self.main.global_data
1540
1541    def run_op(self, op, tracers, params):
1542        vals_in, bdims_in = unzip2((t.val, t.vmap_dim) for t in tracers)
1543        val_outs, bdim_outs = op.vmap(self.dim_size, vals_in, bdims_in, **params)
1544        return [VMapTraceTensor(self, x, bd) for x, bd in list_zip(val_outs, bdim_outs)]
1545
1546    @staticmethod
1547    def move_vmap_dim(x, dim_size, src: Optional[int], dst: int):
1548        if src is None:  # unsqueeze and expand
1549            target_shape = list(x.shape)
1550            target_shape.insert(dst, dim_size)
1551            unsqueeze_shape = [1 if d == dst else target_shape[d] for d in range(len(target_shape))]
1552            x = x.reshape(tuple(unsqueeze_shape))
1553            x = x.expand(tuple(target_shape))
1554            return x
1555        elif src == dst:
1556            return x
1557        else:
1558            perm = [i for i in range(len(x.shape)) if i != src]
1559            perm.insert(dst, src)
1560            return x.permute(tuple(perm))
1561
1562
1563class JVPTraceTensor(TraceTensor):
1564    def __init__(self, trace, primal, tangent):
1565        self._trace = trace
1566        self.primal = primal
1567        self.tangent = tangent
1568
1569    @property
1570    def symval(self):
1571        return get_symval(self.primal)
1572
1573    @property
1574    def val(self):
1575        return self.primal
1576
1577    @property
1578    def dtype(self):
1579        return self.primal.dtype
1580
1581
1582class JVPTrace(Trace):
1583    def pure(self, val):
1584        if isinstance(val, PartialRunTraceTensor):
1585            val = val.pval.const
1586        return JVPTraceTensor(self, val, backend.zeros_like(val))
1587
1588    def run_op(self, op, tracers, params):
1589        primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
1590        primals_out, tangents_out = op.jvp(primals_in, tangents_in, **params)
1591        return [JVPTraceTensor(self, x, t) for x, t in list_zip(primals_out, tangents_out)]
1592
1593
1594class ProgramTraceTensor(TraceTensor):
1595    __slots__ = ["symval"]
1596    symval: SymbolicTensor
1597
1598    def __init__(self, trace, symval):
1599        self._trace = trace
1600        self.symval = symval
1601
1602
1603class ProgramTrace(Trace):
1604    @property
1605    def builder(self):
1606        return self.main.global_data
1607
1608    def new_arg(self, symval) -> ProgramTraceTensor:
1609        symval = SymbolicTensor.like(symval)
1610        tracer = self.builder.new_tracer(self, symval)
1611        self.builder.tracer_to_var[id(tracer)] = Var(symval)
1612
1613        return tracer
1614
1615    def pure(self, val: Any) -> ProgramTraceTensor:
1616        # get_or_make_const_tracer
1617        tracer = self.builder.const_tracers.get(id(val))
1618        if tracer is None:
1619            tracer = self.builder.new_tracer(self, get_symval(val))
1620            self.builder.add_const(tracer, val)
1621        # print(self.builder.const_tracers)
1622        return tracer
1623
1624    def run_op(self, op, tracers, params):
1625        symvals_in = tree_map(lambda x: x.symval, tracers)
1626        symvals_out = op.typecheck(*symvals_in, **params)
1627
1628        out_tracers = [self.builder.new_tracer(self, a) for a in symvals_out]
1629        inputs = [self.builder.getvar(t) for t in tracers]
1630        outvars = [self.builder.add_var(t) for t in out_tracers]
1631
1632        self.builder.add_instruction(Instruction(op, inputs, params, outvars))
1633        return out_tracers
1634
1635
1636class ProgramBuilder:
1637    instructions: List[Instruction]
1638    tracer_to_var: Dict[int, Var]
1639    const_tracers: Dict[int, TraceTensor]
1640    constvals: Dict[Var, Any]
1641    tracers: List[ProgramTraceTensor]
1642
1643    def __init__(self):
1644        self.instructions = []
1645        self.tracer_to_var = {}
1646        self.const_tracers = {}
1647        self.constvals = {}
1648        self.tracers = []
1649
1650    def new_tracer(self, trace: ProgramTrace, symval: SymbolicTensor) -> ProgramTraceTensor:
1651        tracer = ProgramTraceTensor(trace, symval)
1652        self.tracers += [tracer]
1653        return tracer
1654
1655    def add_instruction(self, instruction: Instruction) -> None:
1656        self.instructions += [instruction]
1657
1658    def add_var(self, tracer: ProgramTraceTensor) -> Var:
1659        assert id(tracer) not in self.tracer_to_var
1660        var = self.tracer_to_var[id(tracer)] = Var(tracer.symval)
1661        return var
1662
1663    def getvar(self, tracer: ProgramTraceTensor) -> Var:
1664        var = self.tracer_to_var.get(id(tracer))
1665        assert var is not None
1666        return var
1667
1668    def add_const(self, tracer: ProgramTraceTensor, val: Any) -> Var:
1669        var = self.add_var(tracer)
1670        self.const_tracers[id(val)] = tracer
1671        self.constvals[var] = val
1672        return var
1673
1674    def build(self, in_tracers: Any, out_tracers: Any, static_args, name) -> Tuple[Program, List[Any]]:
1675        constvars, constvals = unzip2(self.constvals.items())
1676        t2v = lambda t: self.tracer_to_var[id(t)]
1677        in_binders = constvars + [t2v(t) for t in in_tracers]
1678        out_vars = [t2v(t) for t in out_tracers]
1679        program = Program(
1680            in_binders,
1681            self.instructions,
1682            out_vars,
1683            len(constvals),
1684            static_args,
1685            name,
1686        )
1687        typecheck_program(program)
1688        program, constvals = self._inline_literals(program, constvals)
1689        typecheck_program(program)
1690        # dblog(program, enable=backend.LOG_PROGRAM)
1691        return program, constvals
1692
1693    def _inline_literals(self, program: Program, consts: List[Any]) -> Tuple[Program, List[Any]]:
1694        const_binders, other_binders = split_list(program.in_binders, len(consts))
1695        scalars = [type(x) in TraceTensor.PYTHON_TYPES and not get_symval(x).shape for x in consts]
1696        new_const_binders, lit_binders = partition_list(scalars, const_binders)
1697        new_consts, lit_vals = partition_list(scalars, consts)
1698        literals = dict(list_zip(lit_binders, list_map(Lit, lit_vals)))
1699        new_outs = [literals.get(x, x) for x in program.outs]
1700        new_instructions = [
1701            Instruction(
1702                instruction.op,
1703                [literals.get(x, x) for x in instruction.inputs],
1704                instruction.params,
1705                instruction.out_binders,
1706            )
1707            for instruction in program.instructions
1708        ]
1709        new_program = Program(
1710            new_const_binders + other_binders,
1711            new_instructions,
1712            new_outs,
1713            len(new_consts),
1714            program.static_args,
1715            program.name,
1716        )
1717        return new_program, tuple(new_consts)
1718
1719    def get_current_scope_info(self):
1720        current_frame = inspect.currentframe()
1721        current_function_name = current_frame.f_code.co_name
1722        current_module_name = inspect.getmodulename(current_frame.f_code.co_filename)
1723        current_class_name = None
1724        for frame_info in inspect.getouterframes(current_frame):
1725            print(frame_info)
1726            frame_locals = frame_info.frame.f_locals
1727            print(frame_locals)
1728            if "self" in frame_locals:
1729                current_class_name = frame_locals["self"].__class__.__name__
1730                break
1731        return {
1732            "Function": current_function_name,
1733            "Module": current_module_name,
1734            "Class": current_class_name,
1735        }
1736
1737
1738class UndefinedPrimal(NamedTuple):
1739    symval: SymbolicTensor
1740
1741    @property
1742    def shape(self):
1743        return self.symval.shape
1744
1745    @property
1746    def dtype(self):
1747        return self.symval.dtype
1748
1749    @property
1750    def device(self):
1751        return self.symval.device
1752
1753    @property
1754    def ndim(self):
1755        return self.symval.ndim
1756
1757    def __repr__(self):
1758        return f"<UndefinedPrimal: symval={self.symval}>"
1759
1760    str_short = __repr__
1761
1762
1763class PartialValue(NamedTuple):
1764    symval: SymbolicTensor
1765    const: Optional[Any]
1766
1767    is_known = property(lambda self: self.const is not None)
1768    is_unknown = property(lambda self: self.const is None)
1769
1770
1771class LambdaBindingDraft(NamedTuple):
1772    pass
1773
1774
1775class ConstDraft(NamedTuple):
1776    val: Any
1777
1778
1779class InstructionDraft(NamedTuple):
1780    prim: Operator
1781    tracers_in: List["PartialRunTraceTensor"]
1782    params: Dict[str, Any]
1783    symvals_out: List[SymbolicTensor]
1784    tracer_refs_out: List[weakref.ReferenceType["PartialRunTraceTensor"]]
1785
1786
1787ProgramDraft = Union[LambdaBindingDraft, ConstDraft, InstructionDraft]
1788
1789
1790class PartialRunTraceTensor(TraceTensor):
1791    def __init__(self, trace, pval, draft):
1792        self._trace = trace
1793        self.pval = pval
1794        self.draft = draft
1795
1796    symval = property(lambda self: self.pval.symval)
1797    val = property(lambda self: self.pval.const)
1798
1799    def full_lower(self):
1800        if self.pval.is_known:
1801            return full_lower(self.pval.const)
1802        return self
1803
1804
1805class PartialRunTrace(Trace):
1806    def new_arg(self, pval: PartialValue) -> Any:
1807        return PartialRunTraceTensor(self, pval, LambdaBindingDraft())
1808
1809    def pure(self, val: Any) -> PartialRunTraceTensor:
1810        return PartialRunTraceTensor(self, make_known_pval(val), None)
1811
1812    def instantiate_const(self, tracer: PartialRunTraceTensor) -> PartialRunTraceTensor:
1813        if tracer.pval.is_unknown:
1814            return tracer
1815        else:
1816            pval = make_unknown_pval(SymbolicTensor.like(tracer.symval))
1817            return PartialRunTraceTensor(self, pval, ConstDraft(tracer.pval.const))
1818
1819    def run_op(self, op, tracers, params):
1820        is_knowns = tuple(t.pval.is_known for t in tracers)
1821
1822        if all(is_knowns):
1823            return bind(op, *list_map(full_lower, tracers), **params)
1824        return op.partial_run(self, tracers, **params)
1825
1826
1827trace_stack: List[MainTrace] = []
1828stashed_trace: Optional[MainTrace] = None
1829trace_stack += [MainTrace(0, RunTrace, None)]
1830
1831
1832class UndefBackend:
1833    def __getattr__(self, attr):
1834        raise NotImplementedError("Backend not initialized yet; call slope.core.set_backend(name) first")
1835
1836
1837backend = UndefBackend()
1838
1839
1840def set_backend(name, where="slope.backends"):
1841    global backend
1842    backend = importlib.import_module(f"{where}.{name}").backend
1843    import slope.nn as nn
1844
1845    # backend.register_node(nn.Module, nn.Module.flatten, nn.Module.unflatten, "Module")
1846
1847    dblog(f"slope backend is {backend}", enable=backend.LOG_INIT)
1848
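A minimal initialization sketch, assuming a backend module exists under slope.backends (the name "numpy" below is hypothetical) and exposes a `backend` object with the factory ops referenced throughout this file:

import slope.core as core

# Hypothetical backend name; any module under slope.backends that exposes a
# `backend` attribute is loaded the same way.
core.set_backend("numpy")

# The module-level `backend` global is now the loaded backend, so ops used
# elsewhere in this file (backend.ones, backend.zeros, ...) resolve to it.
x = core.backend.ones((3,))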
1849
1850def stack_str():
1851    ret = ""
1852    for trace in trace_stack:
1853        ret += f"{trace.level}: {trace.trace_type.__name__}\t{trace.global_data=}\n"
1854    return ret
1855
1856
1857def make_known_pval(val: Any):
1858    return PartialValue(get_symval(val), val)
1859
1860
1861def make_unknown_pval(symval: SymbolicTensor):
1862    return PartialValue(symval, None)
1863
1864
1865def get_symval(x):
1866    if isinstance(x, TraceTensor):
1867        return x.symval
1868    elif type(x) in TraceTensor.PYTHON_TYPES:
1869        return backend.tensor(x)
1870    elif isinstance(x, Tensor):
1871        return x
1872    elif isinstance(x, SymbolicTensor):
1873        return x
1874    else:
1875        raise TypeError(type(x))
1876
1877
1878def tree_flatten(x: Any) -> Any:
1879    def _tree_flatten(x_: Any) -> Tuple[Iterable, Union[TreeDef, Leaf]]:
1880        node_type = None
1881        for k in backend.node_types.keys():
1882            if isinstance(x_, k):
1883                node_type = backend.node_types[k]
1884
1885        if node_type is not None:
1886            node_metadata, children = node_type.flatten(x_)
1887            children_flat, child_trees = unzip2(list_map(_tree_flatten, children))
1888            children_iter = itertools.chain.from_iterable(children_flat)
1889            treedef = TreeDef(node_type, node_metadata, tuple(child_trees))
1890            return children_iter, treedef
1891        else:
1892            return (x_,), Leaf(x_)
1893
1894    children_iter, treedef = _tree_flatten(x)
1895    return tuple(children_iter), treedef
1896
1897
1898def tree_unflatten(treedef: TreeDef, xs: Tuple[Any]) -> Any:
1899    def _tree_unflatten(treedef_: TreeDef, xs_: Iterator) -> Any:
1900        if isinstance(treedef_, Leaf):
1901            dblog(f"    tree leaf found: {xs_}\n", enable=backend.LOG_TREE)
1902            return next(xs_)
1903        else:
1904            dblog(f"    now\n  {treedef_}", enable=backend.LOG_TREE)
1905            children = (_tree_unflatten(t, xs_) for t in treedef_.child_treedefs)
1906            dblog(f"{children=}\n", enable=backend.LOG_TREE)
1907            return treedef_.node_type.unflatten(treedef_.node_metadata, children)
1908
1909    dblog(f"unflattening {treedef}", enable=backend.LOG_TREE)
1910    return _tree_unflatten(treedef, iter(xs))
1911    # with Timing(f"\nTREE:\n{treedef}"):
1912    #     ret = _tree_unflatten(treedef, iter(xs))
1913    # return ret
1914
1915
1916def tree_transpose(
1917    outer_treedef: TreeDef,
1918    inner_treedef: TreeDef,
1919    tree_to_transpose: Any,
1920) -> Any:
1921    flat, treedef = tree_flatten(tree_to_transpose)
1922    inner_size = inner_treedef.num_leaves
1923    outer_size = outer_treedef.num_leaves
1924    if treedef.num_leaves != (inner_size * outer_size):
1925        raise TypeError
1926    iter_flat = iter(flat)
1927    lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)]
1928    permuted_lol = zip(*lol)
1929    subtrees = map(partial(tree_unflatten, outer_treedef), permuted_lol)
1930    return tree_unflatten(inner_treedef, subtrees)
1931
1932
1933def flatten_fn(f, in_tree, *, has_aux=False):
1934    store = Store()
1935
1936    def flat_fn(*args_flat, **params):
1937        tree_args = tree_unflatten(in_tree, args_flat)
1938        out = f(*tree_args, **params)
1939        if has_aux:
1940            out, aux = out
1941        out_flat, out_tree = tree_flatten(out)
1942        store.set_value(out_tree)
1943        return (out_flat, aux) if has_aux else out_flat
1944
1945    return flat_fn, store
1946
1947
1948def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any:
1949    leaves, treedef = tree_flatten(tree)
1950    if len(rest) == 0:
1951        out_tree_flat = tuple(f(leaf) for leaf in leaves)
1952        out_tree = tree_unflatten(treedef, out_tree_flat)
1953    else:
1954        all_leaves = [leaves]
1955        for t in rest:
1956            t_leaves, t_treedef = tree_flatten(t)
1957            assert t_treedef == treedef
1958            all_leaves += [t_leaves]
1959
1960        out_tree_flat = tuple(f(*xs) for xs in zip(*all_leaves))
1961        out_tree = tree_unflatten(treedef, out_tree_flat)
1962    ret = out_tree
1963    if out_leaf:
1964        ret = (ret, tree_flatten(out_tree_flat[0]))
1965    return ret
1966
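A sketch of the pytree round trip, assuming a backend is set and has tuple/list registered in backend.node_types (otherwise the whole structure is treated as a single leaf), and that leaves support `* 2`:

import slope.core as core
# assumes core.set_backend(...) has already been called

tree = (core.backend.ones((2,)), [core.backend.zeros((3,)), 1.0])

leaves, treedef = core.tree_flatten(tree)              # leaves in left-to-right order
rebuilt = core.tree_unflatten(treedef, leaves)         # same structure, same leaves
doubled = core.tree_map(lambda leaf: leaf * 2, tree)   # f applied to every leaf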
1967
1968@contextmanager
1969def new_main_trace(trace_type: Type["Trace"], global_data=None):
1970    global trace_stack
1971    level = len(trace_stack)
1972    main = MainTrace(level, trace_type, global_data)
1973    trace_stack += [main]
1974
1975    try:
1976        yield main
1977    finally:
1978        trace_stack.pop()
1979
1980
1981def bind(op, *args, **params):
1982    top_trace = find_top_trace(args)
1983    tracers = tuple([full_raise(top_trace, arg) for arg in args])
1984    outs = top_trace.run_op(op, tracers, params)
1985    lowered = tuple([full_lower(out) for out in outs])
1986    return lowered
1987
1988
1989def find_top_trace(xs) -> Trace:
1990    arrs = []
1991
1992    def get_arr_from_seq(seq):
1993        nonlocal arrs
1994        for x in seq:
1995            if type(x) in (tuple, list):
1996                get_arr_from_seq(x)
1997            elif isinstance(x, TraceTensor):
1998                arrs += [x]
1999
2000    get_arr_from_seq(xs)
2001    arrs = tuple(arrs)
2002    top_main = max(
2003        (x._trace.main for x in arrs),
2004        default=trace_stack[0],
2005        key=operator_py.attrgetter("level"),
2006    )
2007    if stashed_trace and stashed_trace.level > top_main.level:
2008        top_main = stashed_trace
2009    return top_main.trace_type(top_main)
2010
2011
2012def full_raise(trace: Trace, val: Any) -> TraceTensor:
2013    if not isinstance(val, TraceTensor):
2014        return trace.pure(val)
2015    level = trace.main.level
2016    if val._trace.main is trace.main:
2017        return val
2018    elif val._trace.main.level < level:
2019        return trace.pure(val)
2020    elif val._trace.main.level > level:
2021        raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
2022    else:
2023        raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
2024
2025
2026def full_lower(val: Any):
2027    if isinstance(val, TraceTensor):
2028        return val.full_lower()
2029    elif type(val) in (list, tuple):
2030        return tuple(full_lower(v) for v in val)
2031    else:
2032        return val
2033
2034
2035def typecheck_program(program: Program) -> ProgramType:
2036    env: Set[Var] = set()
2037
2038    for v in program.in_binders:
2039        if v in env:
2040            raise TypeError
2041        env.add(v)
2042
2043    for instruction in program.instructions:
2044        in_types = [typecheck_atom(env, x) for x in instruction.inputs]
2045        out_types = instruction.op.typecheck(*in_types, **instruction.params)
2046        for out_binder, out_type in list_zip(instruction.out_binders, out_types):
2047            if not out_type == out_binder.symval:
2048                raise TypeError
2049        for out_binder in instruction.out_binders:
2050            if out_binder in env:
2051                raise TypeError
2052            env.add(out_binder)
2053
2054    in_types = [v.symval for v in program.in_binders]
2055    out_types = [typecheck_atom(env, x) for x in program.outs]
2056    return ProgramType(tuple(in_types), tuple(out_types))
2057
2058
2059def typecheck_atom(env: Set[Var], x: Atom) -> SymbolicTensor:
2060    if isinstance(x, Var):
2061        if x not in env:
2062            raise TypeError("unbound variable")
2063        return x.symval
2064    elif isinstance(x, Lit):
2065        return get_symval(x.val)
2066    else:
2067        assert False
2068
2069
2070def run_program(program: Program, args: List[Any]) -> List[Any]:
2071    env: Dict[Var, Any] = {}
2072
2073    def read(x: Atom) -> Any:
2074        return env[x] if type(x) is Var else x.val
2075
2076    def write(v: Var, val: Any) -> None:
2077        assert v not in env  # single-assignment
2078        env[v] = val
2079
2080    list_map(write, program.in_binders, args)
2081    for instruction in program.instructions:
2082        in_vals = list_map(read, instruction.inputs)
2083        outs = bind(instruction.op, *in_vals, **instruction.params)
2084        list_map(write, instruction.out_binders, outs)
2085    return list_map(read, program.outs)
2086
2087
2088def program_as_fun(program: Program):
2089    return lambda *args: run_program(program, args)
2090
2091
2092def vmap_flat(f, in_dim, out_dim, dim_size, *args):
2093    if dim_size is None:
2094        dims = set([x.shape[d] for x, d in list_zip(args, in_dim) if d is not None])
2095        assert len(dims) == 1
2096        (dim_size,) = dims
2097    with new_main_trace(VMapTrace, dim_size) as main:
2098        trace = VMapTrace(main)
2099        tracers_in = [VMapTraceTensor(trace, x, dim) if dim is not None else x for x, dim in list_zip(args, in_dim)]
2100        outs = f(*tracers_in)
2101        tracers_out = [full_raise(trace, out) for out in outs]
2102        vals_out, y_vmap_dims = unzip2((t.val, t.vmap_dim) for t in tracers_out)
2103    ret = [VMapTrace.move_vmap_dim(val_out, dim_size, bdim, out_dim) for val_out, bdim, out_dim in zip(vals_out, y_vmap_dims, out_dim)]
2104    return ret
2105
2106
2107def vmap(f, in_dim=0, out_dim=0, dim_size=None):
2108    def batched_f(*args):
2109        nonlocal in_dim, out_dim, dim_size
2110        args_flat, in_tree = tree_flatten(args)
2111        in_dim = (in_dim,) * len(args) if isinstance(in_dim, int) else in_dim
2112        out_dim = (out_dim,) * len(args) if isinstance(out_dim, int) else out_dim
2113        in_dim_flat, in_dim_tree = tree_flatten(in_dim)
2114        out_dim_flat, out_dim_tree = tree_flatten(out_dim)
2115        if not (in_tree == in_dim_tree == out_dim_tree):
2116            raise TypeError(f"\n{in_tree}\n!=\n{in_dim_tree}!=\n{out_dim_tree}")
2117        f_flat, out_tree_store = flatten_fn(f, in_tree)
2118        # if len(args_flat) > len(in_dim_flat):
2119        #     in_dim_flat = (in_dim[0],) * len(args_flat)
2120        outs_flat = vmap_flat(f_flat, in_dim_flat, out_dim_flat, dim_size, *args_flat)
2121        return tree_unflatten(out_tree_store(), outs_flat)
2122
2123    return batched_f
2124
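A usage sketch for vmap, assuming a backend is set and Tensors support elementwise arithmetic; the mapped dimension is stripped inside the function and restored on the output via move_vmap_dim:

import slope.core as core
# assumes core.set_backend(...) has already been called

def scale(x):
    return x * 2                      # sees a single (3,) row inside vmap

xs = core.backend.ones((4, 3))        # batch of 4 rows
ys = core.vmap(scale, in_dim=0, out_dim=0)(xs)   # output keeps the batch dim at 0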
2125
2126def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args):
2127    with new_main_trace(JVPTrace, global_data) as main:
2128        trace = JVPTrace(main)
2129        tracers_in = [JVPTraceTensor(trace, x, t) for x, t in list_zip(primals, tangents)]
2130        jvp_flat_ret = f(*tracers_in, **static_args)
2131        if has_aux:
2132            (outs, aux) = jvp_flat_ret
2133            # aux_ = aux
2134            aux = tree_map(lambda x: x.primal, aux)
2135            # aux = tree_map(lambda x: x.full_lower(), aux)
2136            #
2137        else:
2138            outs = jvp_flat_ret
2139        tracers_out = [full_raise(trace, out) for out in outs]
2140        primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
2141    return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
2142
2143
2144def jvp(f, primals, tangents, *, has_aux=False, global_data=None, **static_args):
2145    primals_flat, in_tree = tree_flatten(primals)
2146    tangents_flat, in_tree2 = tree_flatten(tangents)
2147    for p, t in zip(primals_flat, tangents_flat):
2148        assert p.shape == t.shape, f"{p.shape=} != {t.shape=}"
2149        assert p.dtype == t.dtype, f"{p.dtype=} != {t.dtype=}"
2150        assert p.device == t.device, f"{p.device=} != {t.device=}"
2151    if in_tree != in_tree2:
2152        raise TypeError
2153    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2154    jvp_ret = jvp_flat(
2155        f,
2156        primals_flat,
2157        tangents_flat,
2158        has_aux=has_aux,
2159        global_data=global_data,
2160        **static_args,
2161    )
2162    if has_aux:
2163        (primals_out_flat, tangents_out_flat), aux = jvp_ret
2164    else:
2165        (primals_out_flat, tangents_out_flat) = jvp_ret
2166    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2167    tangents_out = tree_unflatten(out_tree_store(), tangents_out_flat)
2168    return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
2169
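A forward-mode sketch, assuming a backend is set; primals and tangents are passed as matching pytrees with matching shapes, dtypes, and devices (checked above):

import slope.core as core
# assumes core.set_backend(...) has already been called

def f(x):
    return x * x

x = core.backend.ones((3,))
t = core.backend.ones((3,))           # tangent, same shape/dtype/device as x
y, y_dot = core.jvp(f, (x,), (t,))    # y == f(x); y_dot is the JVP of f at x along t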
2170
2171def jacfwd(f, argnums=0, has_aux=False):
2172    def jvp_fn(x):
2173        return jvp(f, x, (backend.eye(len(x)),), has_aux=has_aux)
2174    return vmap(jvp_fn, in_dim=argnums)
2175
2176def jacrev(f, argnums=0, has_aux=False):
2177    def grad_f(x):
2178        return grad(lambda x: f(x) @ backend.eye(f(x).shape[0]), argnums=argnums, has_aux=has_aux)(x)
2179    return vmap(grad_f)
2180
2181## arange version
2182# def jacrev(f, x):
2183#     def grad_f_i(x, i):
2184#         return grad(lambda x: f(x)[i])(x)
2185#     return vmap(lambda i: grad_f_i(x, i))(backend.arange(f(x).shape[0]))
2186
2187def hessian(fn, argnums=0, has_aux=False):
2188    return jacrev(jacrev(fn, argnums=argnums, has_aux=has_aux))
2189
2190@contextmanager
2191def stash_trace(main: MainTrace):
2192    global stashed_trace
2193    prev_stashed_trace, stashed_trace = stashed_trace, main
2194    try:
2195        yield
2196    finally:
2197        stashed_trace = prev_stashed_trace
2198
2199
2200@contextmanager
2201def symbolic_run():
2202    global trace_stack
2203    level = len(trace_stack)
2204    main = MainTrace(level, SymbolicRunTrace, global_data=None)
2205    trace_stack += [main]
2206    global stashed_trace
2207    prev_stashed_trace, stashed_trace = stashed_trace, main
2208    try:
2209        yield
2210    finally:
2211        stashed_trace = prev_stashed_trace
2212        trace_stack.pop()
2213
2214
2215@lru_cache_verbose()
2216def make_program(f: Callable, *symvals_in: SymbolicTensor, static_args, name) -> Tuple[Program, List[Any], TreeDef]:
2217    symvals_in, in_tree = tree_flatten(symvals_in)
2218    f, out_tree_store = flatten_fn(f, in_tree)
2219    builder = ProgramBuilder()
2220    with new_main_trace(ProgramTrace, builder) as main:
2221        with stash_trace(main):
2222            trace = ProgramTrace(main)
2223            tracers_in = [trace.new_arg(symval) for symval in symvals_in]
2224            outs = f(*tracers_in, **{k: v for k, v in static_args})
2225            tracers_out = [full_raise(trace, out) if isinstance(out, ProgramTraceTensor) else out.val for out in outs]
2226            program, consts = builder.build(tracers_in, tracers_out, static_args, name)
2227
2228    return program, consts, out_tree_store()
2229
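A tracing sketch, assuming a backend is set: make_program stages a Python callable into a Program by running it on SymbolicTensor inputs, and run_program interprets that Program with concrete arguments (program_as_fun wraps the same interpreter as a callable). Arguments must be hashable because of the lru_cache_verbose decorator above.

import slope.core as core
# assumes core.set_backend(...) has already been called

def f(x):
    return x + x

x = core.backend.ones((2,))
symval = core.SymbolicTensor.like(core.get_symval(x))
program, consts, out_tree = core.make_program(f, symval, static_args=(), name="f")

(y,) = core.run_program(program, [*consts, x])   # interpret the staged Program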
2230
2231@lru_cache_verbose()
2232def vmap_program(program: Program, dim_size, dims_in) -> tuple[Program, list[Any]]:
2233    def unmapped_symval(axis_size: int, batch_dim, symval: SymbolicTensor) -> SymbolicTensor:
2234        if batch_dim is None:
2235            return symval
2236        else:
2237            shape = list(symval.shape)
2238            shape.insert(batch_dim, axis_size)
2239            return symval.like(shape=tuple(shape))
2240
2241    vmap_traceable = vmap(program_as_fun(program), tuple(dims_in))
2242    in_symvals = [unmapped_symval(dim_size, d, v.symval) for v, d in zip(program.in_binders, dims_in)]
2243    program, consts, _ = make_program(
2244        vmap_traceable,
2245        *in_symvals,
2246        static_args=program.static_args,
2247        name=f"vmap_{program.name}",
2248    )
2249    return program, consts
2250
2251
2252@lru_cache_verbose()
2253def jvp_program(program: Program) -> Tuple[Program, List[Any]]:
2254    def jvp_traceable(*primals_and_tangents):
2255        n = len(primals_and_tangents) // 2
2256        primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
2257        return jvp(program_as_fun(program), primals, tangents)
2258
2259    in_symvals = tree_map(lambda v: v.symval, program.in_binders)
2260    new_program, new_consts, _ = make_program(
2261        jvp_traceable,
2262        *in_symvals,
2263        *in_symvals,
2264        static_args=program.static_args,
2265        name=f"{program.name}_jvp",
2266    )
2267    return new_program, new_consts
2268
2269
2270def partial_run_flat(
2271    f: Callable, pvals_in: List["PartialValue"], has_aux, global_data=None
2272) -> Tuple[Program, List["PartialValue"], List[Any]]:
2273    with new_main_trace(PartialRunTrace, global_data) as main:
2274        trace = PartialRunTrace(main)
2275        tracers_in = [trace.new_arg(pval) for pval in pvals_in]
2276        outs = f(*tracers_in)
2277        if has_aux:
2278            outs, aux = outs
2279        tracers_out = [full_raise(trace, out) for out in outs]
2280        pvals_out = [t.pval for t in tracers_out]
2281        unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
2282        unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
2283        program, consts = tracers_to_program(unk_tracers_in, unk_tracers_out)
2284
2285    return (program, pvals_out, consts, aux) if has_aux else (program, pvals_out, consts)
2286
2287
2288def partial_run_program(
2289    program: Program,
2290    in_unknowns: List[bool],
2291    instantiate: Optional[List[bool]] = None,
2292) -> Tuple[Program, Program, List[bool], int]:
2293    env: Dict[Var, bool] = {}
2294    residuals: Set[Var] = set()
2295
2296    def read(x: Atom) -> bool:
2297        return type(x) is Var and env[x]
2298
2299    def write(unk: bool, v: Var) -> None:
2300        env[v] = unk
2301
2302    instructions1, instructions2 = [], []
2303    list_map(write, in_unknowns, program.in_binders)
2304
2305    for instruction in program.instructions:
2306        unks_in = list_map(read, instruction.inputs)
2307        (
2308            instruction1,
2309            instruction2,
2310            unks_out,
2311            res,
2312        ) = instruction.op.partial_run_instruction(unks_in, instruction)
2313        if instruction1 is not None:
2314            instructions1 += [instruction1]
2315        if instruction2 is not None:
2316            instructions2 += [instruction2]
2317        if res is not None:
2318            residuals.update(res)
2319        list_map(write, unks_out, instruction.out_binders)
2320
2321    out_unknowns = list_map(read, program.outs)
2322    if instantiate is not None:
2323        for v, uk, inst in zip(program.outs, out_unknowns, instantiate):
2324            if inst and not uk:
2325                if type(v) is Var:
2326                    residuals.add(v)
2327        out_unknowns = list_map(operator_py.or_, out_unknowns, instantiate)
2328
2329    residuals, num_res = list(residuals), len(residuals)
2330    assert all(type(v) is Var for v in residuals), residuals
2331
2332    ins1, ins2 = partition_list(in_unknowns, program.in_binders)
2333    outs1, outs2 = partition_list(out_unknowns, program.outs)
2334
2335    program1 = Program(
2336        ins1,
2337        instructions1,
2338        outs1 + residuals,
2339        0,
2340        program.static_args,
2341        f"{program.name}_partial1",
2342    )
2343    program2 = Program(
2344        residuals + ins2,
2345        instructions2,
2346        outs2,
2347        0,
2348        program.static_args,
2349        f"{program.name}_partial2",
2350    )
2351    typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2)
2352
2353    return program1, program2, out_unknowns, num_res
2354
2355
2356def typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2):
2357    programty = typecheck_program(program)  # (a1,  a2) -> (b1, b2 )
2358    program1ty = typecheck_program(program1)  #  a1       -> (b1, res)
2359    program2ty = typecheck_program(program2)  # (res, a2) -> b2
2360
2361    a1, a2 = partition_list(in_unknowns, programty.in_types)
2362    b1, b2 = partition_list(out_unknowns, programty.out_types)
2363    b1_, res = split_list(program1ty.out_types, len(b1))
2364    res_, a2_ = split_list(program2ty.in_types, len(res))
2365    b2_ = program2ty.out_types
2366
2367    a1 = tuple(a1)
2368    a2, a2_ = tuple(a2), tuple(a2_)
2369    b1, b1_ = tuple(b1), tuple(b1_)
2370    b2, b2_ = tuple(b2), tuple(b2_)
2371    res, res_ = tuple(res), tuple(res_)
2372
2373    if program1ty.in_types != a1:
2374        raise TypeError
2375    if program2ty.out_types != b2:
2376        raise TypeError
2377    if b1 != b1_:
2378        raise TypeError
2379    if res != res_:
2380        raise TypeError
2381    if a2 != a2_:
2382        raise TypeError
2383    if b2 != b2_:
2384        raise TypeError
2385
2386
2387def linearize_flat(f, *primals_in, has_aux):
2388    pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in]
2389
2390    def f_jvp(*primals_tangents_in):
2391        jvp_ret = jvp(f, *split_half(primals_tangents_in), has_aux=has_aux)
2392        if has_aux:
2393            (primals_out, tangents_out), aux = jvp_ret
2394            return ((*primals_out, *tangents_out), aux)
2395        else:
2396            primals_out, tangents_out = jvp_ret
2397            return (*primals_out, *tangents_out)
2398
2399    partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux)
2400    if has_aux:
2401        program, pvals_out, consts, aux = partial_run_flat_ret
2402    else:
2403        program, pvals_out, consts = partial_run_flat_ret
2404    primal_pvals, _ = split_half(pvals_out)
2405    assert all(pval.is_known for pval in primal_pvals)
2406    primals_out = [pval.const for pval in primal_pvals]
2407    f_lin = lambda *tangents: run_program(program, [*consts, *tangents])
2408    return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
2409
2410
2411def linearize(f, *primals_in, has_aux=False):
2412    primals_in_flat, in_tree = tree_flatten(primals_in)
2413    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2414    linearize_flat_ret = linearize_flat(f, *primals_in_flat, has_aux=has_aux)
2415    if has_aux:
2416        primals_out_flat, f_lin_flat, aux = linearize_flat_ret
2417    else:
2418        primals_out_flat, f_lin_flat = linearize_flat_ret
2419
2420    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2421
2422    def f_lin(*tangents_in):
2423        tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
2424        if in_tree != in_tree2:
2425            raise TypeError
2426        tangents_out_flat = f_lin_flat(*tangents_in_flat)
2427        return tree_unflatten(out_tree_store(), tangents_out_flat)
2428
2429    return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
2430
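A linearization sketch, assuming a backend is set: linearize runs f once under jvp with unknown tangents, partially evaluates the trace, and returns the primal outputs plus a function that replays only the tangent (linear) part:

import slope.core as core
# assumes core.set_backend(...) has already been called

def f(x):
    return x * x

x = core.backend.ones((3,))
y, f_lin = core.linearize(f, x)          # y == f(x); f_lin is linear in its input
y_dot = f_lin(core.backend.ones((3,)))   # same tangent output jvp would produce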
2431
2432def tracers_to_program(
2433    tracers_in: List["PartialRunTraceTensor"],
2434    tracers_out: List["PartialRunTraceTensor"],
2435):
2436    def tracer_parents(t: PartialRunTraceTensor) -> List[PartialRunTraceTensor]:
2437        return t.draft.tracers_in if isinstance(t.draft, InstructionDraft) else []
2438
2439    def draft_to_instruction(tracer_to_var: Dict[int, Var], draft: InstructionDraft) -> Instruction:
2440        inputs = [tracer_to_var[id(t)] for t in draft.tracers_in]
2441        out_binders = [Var(symval) for symval in draft.symvals_out]
2442        for t_ref, var in list_zip(draft.tracer_refs_out, out_binders):
2443            if t_ref() is not None:
2444                tracer_to_var[id(t_ref())] = var
2445        return Instruction(draft.prim, inputs, draft.params, out_binders)
2446
2447    tracer_to_var: Dict[int, Var] = {id(t): Var(SymbolicTensor.like(t.symval)) for t in tracers_in}
2448    constvar_to_val: Dict[int, Any] = {}
2449    constid_to_var: Dict[int, Var] = {}
2450    processed_instructions: Set[int] = set()
2451    instructions: List[Instruction] = []
2452    for t in toposort(tracers_out, tracer_parents):
2453        if isinstance(t.draft, LambdaBindingDraft):
2454            assert id(t) in set(list_map(id, tracers_in))
2455        elif isinstance(t.draft, ConstDraft):
2456            val = t.draft.val
2457            var = constid_to_var.get(id(val))
2458            if var is None:
2459                symval = SymbolicTensor.like(get_symval(val))
2460                var = constid_to_var[id(val)] = Var(symval)
2461                constvar_to_val[var] = val
2462            tracer_to_var[id(t)] = var
2463        elif isinstance(t.draft, InstructionDraft):
2464            if id(t.draft) not in processed_instructions:
2465                instructions += [draft_to_instruction(tracer_to_var, t.draft)]
2466                processed_instructions.add(id(t.draft))
2467        else:
2468            raise TypeError(t.draft)
2469
2470    constvars, constvals = unzip2(constvar_to_val.items())
2471    in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
2472    # out_vars = [tracer_to_var[id(t)] for t in tracers_out if id(t) in tracer_to_var]
2473    out_vars = [tracer_to_var[id(t)] for t in tracers_out]
2474    program = Program(tuple(in_binders), tuple(instructions), tuple(out_vars))
2475    typecheck_program(program)
2476    return program, constvals
2477
2478
2479def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
2480    def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
2481        seen = set()
2482        for node in nodes:
2483            assert all(id(parent) in seen for parent in parents(node))
2484            seen.add(id(node))
2485
2486    def remove_duplicates(lst):
2487        seen = set()
2488        return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
2489
2490    if not out_nodes:
2491        return []
2492    out_nodes = remove_duplicates(out_nodes)
2493
2494    child_counts = {}
2495    stack = list(out_nodes)
2496    while stack:
2497        node = stack.pop()
2498        if id(node) in child_counts:
2499            child_counts[id(node)] += 1
2500        else:
2501            child_counts[id(node)] = 1
2502            stack.extend(parents(node))
2503    for node in out_nodes:
2504        child_counts[id(node)] -= 1
2505
2506    sorted_nodes = []
2507    childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
2508    while childless_nodes:
2509        node = childless_nodes.pop()
2510        sorted_nodes += [node]
2511        for parent in parents(node):
2512            if child_counts[id(parent)] == 1:
2513                childless_nodes += [parent]
2514            else:
2515                child_counts[id(parent)] -= 1
2516
2517    sorted_nodes = sorted_nodes[::-1]
2518    check_toposort(sorted_nodes, parents)
2519    return sorted_nodes
2520
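toposort only needs a parents() callable, so it can be exercised standalone, without a backend. A hypothetical three-node graph where c depends on a and b:

from slope.core import toposort

class Node:
    def __init__(self, name, parents=()):
        self.name, self._parents = name, list(parents)

a, b = Node("a"), Node("b")
c = Node("c", parents=[a, b])

order = toposort([c], lambda n: n._parents)
print([n.name for n in order])        # parents before children, e.g. ['a', 'b', 'c']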
2521
2522def vjp_flat(f, *primals_in, has_aux=False, **static_args):
2523    pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in]
2524    primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
2525    del primal_pvals_in
2526
2527    def f_jvp(*primals_tangents_in):
2528        jvp_ret = jvp(
2529            f,
2530            *split_half(primals_tangents_in),
2531            has_aux=has_aux,
2532            global_data="vjp",
2533            **static_args,
2534        )
2535        if has_aux:
2536            ((primals_out, tangents_out), aux) = jvp_ret
2537        else:
2538            (primals_out, tangents_out) = jvp_ret
2539        return ([*primals_out, *tangents_out], aux) if has_aux else [*primals_out, *tangents_out]
2540
2541    partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux, "vjp")
2542    if has_aux:
2543        program, pvals_out, consts, aux = partial_run_flat_ret
2544    else:
2545        program, pvals_out, consts = partial_run_flat_ret
2546
2547    primal_pvals, tangent_pvals = split_half(pvals_out)
2548    del tangent_pvals
2549    assert all(pval.is_known for pval in primal_pvals)
2550    primals_out_flat = [pval.const for pval in primal_pvals]
2551    transpose_inputs = consts + [UndefinedPrimal(t.symval) for t in tangent_pvals_in]
2552
2553    def f_vjp_flat(*cotangents):
2554        # return backward_pass(program, transpose_inputs, cotangents)
2555        undef_primals = tuple(isinstance(x, UndefinedPrimal) for x in transpose_inputs)
2556        transposed_program, new_consts = transpose_program(program, undef_primals)
2557        residuals, _ = partition_list(undef_primals, transpose_inputs)
2558        outs = run_program(transposed_program, (*new_consts, *residuals, *cotangents))
2559        return outs
2560
2561    return (primals_out_flat, f_vjp_flat, aux) if has_aux else (primals_out_flat, f_vjp_flat)
2562
2563
2564def vjp(f, *primals_in, has_aux=False, **static_args):
2565    primals_in_flat, in_tree = tree_flatten(primals_in)
2566    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2567    vjp_ret = vjp_flat(f, *primals_in_flat, has_aux=has_aux, **static_args)
2568    if has_aux:
2569        primals_out_flat, f_vjp_flat, aux = vjp_ret
2570    else:
2571        primals_out_flat, f_vjp_flat = vjp_ret
2572    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2573
2574    def f_vjp(*cotangents_out):
2575        cotangents_out_flat, _ = tree_flatten(cotangents_out)
2576        cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
2577
2578        return tree_unflatten(in_tree, cotangents_in_flat)
2579
2580    return (primals_out, f_vjp, aux) if has_aux else (primals_out, f_vjp)
2581
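A reverse-mode sketch, assuming a backend is set: vjp returns the primal outputs together with a pullback that maps output cotangents to input cotangents, structured like the original inputs:

import slope.core as core
# assumes core.set_backend(...) has already been called

def f(x):
    return x * x

x = core.backend.ones((3,))
y, f_vjp = core.vjp(f, x)                   # y == f(x)
(x_bar,) = f_vjp(core.backend.ones((3,)))   # cotangent for x, shaped like the inputs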
2582
2583NullCotangent = None
2584
2585
2586def backward_pass(program: Program, args: List[Any], cotangents: List[Any]) -> List[Any]:
2587    primal_env: Dict[Var, Any] = {}
2588    ct_env: Dict[Var, Any] = {}
2589
2590    def read_primal(x: Atom) -> Any:
2591        return primal_env.get(x, UndefinedPrimal(x.symval)) if type(x) is Var else x.val
2592
2593    def write_primal(v: Var, val: Any) -> None:
2594        if type(val) is not UndefinedPrimal:
2595            primal_env[v] = val
2596
2597    def read_cotangent(v: Var) -> Any:
2598        return ct_env.pop(v, backend.zeros(v.symval.shape, v.symval.dtype))
2599
2600    def write_cotangent(x: Atom, ct: Any):
2601        if type(x) is Var and ct is not NullCotangent:
2602            ct_env[x] = (ct_env[x] + ct) if x in ct_env else ct
2603
2604    list_map(write_primal, program.in_binders, args)
2605    list_map(write_cotangent, program.outs, cotangents)
2606    for instruction in program.instructions[::-1]:
2607        primals_in = list_map(read_primal, instruction.inputs)
2608        cotangents_in = list_map(read_cotangent, instruction.out_binders)
2609        inp, params = primals_in, instruction.params
2610        cotangents_out = instruction.op.T(cotangents_in, *inp, **params)
2611        list_map(write_cotangent, instruction.inputs, cotangents_out)
2612
2613    ret = [read_cotangent(v) for v, x in list_zip(program.in_binders, args) if isinstance(x, UndefinedPrimal)]
2614    return ret
2615
2616
2617@lru_cache_verbose()
2618def transpose_program(program: Program, undef_primals: tuple[bool, ...]) -> tuple[Program, list[Any]]:
2619    symvals_in, symvals_out = typecheck_program(program)
2620    traceable = partial(backward_pass, program)
2621
2622    args = [UndefinedPrimal(a) if u else a for a, u in zip(symvals_in, undef_primals)]
2623    trans_program, consts, _ = make_program(
2624        traceable,
2625        tuple(args),
2626        tuple(symvals_out),
2627        static_args=program.static_args,
2628        name=f"{program.name}_T",
2629    )
2630    typecheck_program(trans_program)
2631
2632    return trans_program, consts
2633
2634
2635def grad(f, argnums=(0,), argnames="", has_aux=False, return_value=False):
2636    f, rejit = (f, False) if not isinstance(f, jit) else (f.f, True)
2637    if isinstance(argnums, int):
2638        argnums = (argnums,)
2639
2640    def gfn(x, *xs, **static_args):
2641        vjp_ret = vjp(f, x, *xs, has_aux=has_aux, **static_args)
2642        if has_aux:
2643            y, f_vjp, aux = vjp_ret
2644        else:
2645            y, f_vjp = vjp_ret
2646        if np.shape(y) != ():
2647            raise TypeError("grad output must be 0-dim scalar with shape ()")
2648        gL_xs = f_vjp(backend.ones(()))
2649        gL_xs = tuple(gL_xs[i] for i in argnums) if len(argnums) > 1 else gL_xs[argnums[0]]
2650        if return_value:
2651            return ((y, aux), gL_xs) if has_aux else (y, gL_xs)
2652        else:
2653            return (gL_xs, aux) if has_aux else gL_xs
2654
2655    return jit(gfn) if rejit else gfn
2656
2657
2658def value_and_grad(f, argnums=(0,), argnames="", has_aux=False):
2659    return grad(
2660        f,
2661        argnums=argnums,
2662        argnames=argnames,
2663        has_aux=has_aux,
2664        return_value=True,
2665    )
2666
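A gradient sketch, assuming a backend is set and that Tensor exposes a sum() reduction (an assumption of this example, not something defined in this file); the loss must return a 0-dim scalar, as gfn above enforces:

import slope.core as core
# assumes core.set_backend(...) has already been called

def loss(w):
    return (w * w).sum()              # assumed Tensor.sum(); must yield shape ()

w = core.backend.ones((3,))
g = core.grad(loss)(w)                # d loss / d w, same shape as w
value, g = core.value_and_grad(loss)(w)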
2667
2668def jit_partial_run(trace, tracers, *, program):
2669    in_unknowns = [not t.pval.is_known for t in tracers]
2670    program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns)
2671    known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
2672    known_vals = [t.pval.const for t in known_tracers]
2673    outs1_res = backend.jit_op(*known_vals, program=program1)
2674    outs1, res = split_list(outs1_res, len(program1.outs) - num_res)
2675    res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
2676    outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs]
2677    draft = InstructionDraft(
2678        backend.jit_op,
2679        res_tracers + unknown_tracers,
2680        dict(program=program2),
2681        [v.symval for v in program2.outs],
2682        map(weakref.ref, outs2),
2683    )
2684    for t in outs2:
2685        t.draft = draft
2686    return merge_lists(out_unknowns, outs1, outs2)
2687
2688
2689class jit:
2690    def __init__(self, f, static_argnames=(), name=None, dynamic_axes=None):
2691        if isinstance(static_argnames, str):
2692            static_argnames = tuple(static_argnames.split(" "))
2693        assert type(static_argnames) is tuple and all(type(s) is str for s in static_argnames)
2694        self.f = f
2695        self.name = name if name is not None else self.f.__name__
2696        self.static_argnames = static_argnames
2697        self.dynamic_axes = dynamic_axes
2698
2699    @classmethod
2700    def with_options(cls, **kwargs):
2701        return partial(cls, **kwargs)
2702
2703    @classmethod
2704    def get_jit_name(cls, args, static_args, prefix="jit", short=False):
2705        name = f"{prefix}_"
2706        if short:
2707            static_args_tup = tuple(static_args.items())
2708            ids = repr(hash((prefix, args, static_args_tup)))[-4:]
2709            name = f"{prefix}_{ids}"
2710        else:
2711            for a in args:
2712                name += f"shape_{a.shape}_dtype_{a.dtype.name}_"
2713            for k, v in static_args.items():
2714                name += f"{k}_{v}_"
2715            name = name.replace("(", "L")
2716            name = name.replace(")", "R")
2717            name = name.replace(",", "C")
2718            name = name.replace(" ", "")
2719            name = name.replace(".", "D")
2720
2721        return name
2722
2723    def get_program(self, *args, **static_args):
2724        sig = inspect.signature(self.f)
2725        if all("*" not in repr(v) for v in sig.parameters.values()):
2726            args_strs = [k for k, v in sig.parameters.items() if k != "self" and k not in self.static_argnames]
2727            static_args_strs = [k for k, v in sig.parameters.items() if k != "self" and k in self.static_argnames]
2728
2729            if args:
2730                if len(args) > len(args_strs):
2731                    assert static_args_strs
2732                    args, rest = args[: len(args_strs)], args[len(args_strs) :]
2733                    new_static_args = {k: rest_arg for k, rest_arg in zip(static_args_strs, rest) if k not in static_args}
2734                    static_args = {**new_static_args, **static_args}
2735            else:
2736                args = tuple([static_args[k] if k in static_args else arg for k, arg in zip(args_strs, args)])
2737
2738        symvals_in = tree_map(lambda x: SymbolicTensor.like(get_symval(x)), args)
2739        static_args = tuple(static_args.items())
2740        if self.name is None:
2741            self.name = f"jit_{str(hash((self.f, symvals_in, static_args)))[-5:]}"
2742        program, consts, out_tree = make_program(self.f, *symvals_in, static_args=static_args, name=self.name)
2743        return program, consts, out_tree
2744
2745    def __call__(self, *args, **static_args):
2746        program, consts, out_tree = self.get_program(*args, **static_args)
2747        args, in_tree = tree_flatten(args)
2748        outs = bind(backend.jit_op, *consts, *args, program=program)
2749        return tree_unflatten(out_tree, outs)
2750
2751    def lower(self, *args, **static_args):
2752        program, consts, out_tree = self.get_program(*args, **static_args)
2753        args, in_tree = tree_flatten(args)
2754        hashed_program = Hashed(program)
2755        num_consts = program.num_consts
2756        consts, args = args[:num_consts], args[num_consts:]
2757        hashed_consts = tuple(map(Hashed, consts))
2758        jit_output = backend.jit_program(hashed_program, hashed_consts)
2759        return jit_output
2760
2761    def export(self, output_path, args, export_params=True, input_names=None, output_names=None, **kwargs):
2762        if isinstance(args, Tensor):
2763            args, static_args = (args,), dict()
2764        elif not isinstance(args[-1], dict):
2765            assert all(isinstance(a, Tensor) for a in args)
2766            static_args = dict()
2767        else:
2768            args, static_args = args[:-1], args[-1]
2769        assert isinstance(args, (tuple, list)) and isinstance(static_args, dict)
2770        jit_output = self.lower(*args, **static_args)
2771        backend.export(jit_output, output_path, export_params, input_names, output_names, **kwargs)
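A jit sketch, assuming a backend is set and implements jit_op/jit_program: each call stages the function into a Program (cached via make_program) and dispatches it through backend.jit_op; static_argnames marks non-tensor arguments that are treated as static configuration rather than traced tensors:

import slope.core as core
# assumes core.set_backend(...) has already been called

@core.jit
def f(x):
    return x * x + x

y = f(core.backend.ones((4,)))        # staged on first call, then dispatched

# non-tensor arguments can be declared static via with_options
g = core.jit.with_options(static_argnames=("n",))(lambda x, n: x * n)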
class Timing(contextlib.ContextDecorator):
46class Timing(ContextDecorator):
47    def __init__(self, prefix="", on_exit=None, enabled=True):
48        self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
49
50    def __enter__(self):
51        self.st = time.perf_counter_ns()
52
53    def __exit__(self, *exc):
54        self.et = time.perf_counter_ns() - self.st
55        if self.enabled:
56            print(f"{self.prefix}{self.et*1e-6:6.2f} ms" + (self.on_exit(self.et) if self.on_exit else ""))

A base class or mixin that enables context managers to work as decorators.

Timing(prefix='', on_exit=None, enabled=True)
47    def __init__(self, prefix="", on_exit=None, enabled=True):
48        self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
def colored(st, color: Optional[str], background=False):
59def colored(st, color: Optional[str], background=False):
60    return (
61        f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m"
62        if color is not None
63        else st
64    )  # replace the termcolor library with one line  # noqa: E501
class Profiling(contextlib.ContextDecorator):
71class Profiling(ContextDecorator):
72    def __init__(self, enabled=True, sort="cumtime", frac=0.2, fn=None, ts=1):
73        self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3 / ts
74
75    def __enter__(self):
76        self.pr = cProfile.Profile()
77        if self.enabled:
78            self.pr.enable()
79
80    def __exit__(self, *exc):
81        if self.enabled:
82            self.pr.disable()
83            if self.fn:
84                self.pr.dump_stats(self.fn)
85            stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
86            for fcn in stats.fcn_list[0 : int(len(stats.fcn_list) * self.frac)]:  # type: ignore[attr-defined]
87                (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn]  # type: ignore[attr-defined]
88                scallers = sorted(callers.items(), key=lambda x: -x[1][2])
89                print(
90                    f"n:{num_calls:8d}  tm:{tottime*self.time_scale:7.2f}ms  tot:{cumtime*self.time_scale:7.2f}ms",
91                    colored(_format_fcn(fcn), "yellow") + " " * (50 - len(_format_fcn(fcn))),
92                    colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else "",
93                )

A base class or mixin that enables context managers to work as decorators.

Profiling(enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1)
72    def __init__(self, enabled=True, sort="cumtime", frac=0.2, fn=None, ts=1):
73        self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3 / ts
def dblog(*msg, enable=True):
96def dblog(*msg, enable=True):
97    if enable:
98        print(*msg)
def unzip2(pairs) -> Tuple[List[Any], List[Any]]:
101def unzip2(pairs) -> Tuple[List[Any], List[Any]]:
102    lst1, lst2 = [], []
103    for i1, i2 in pairs:
104        lst1 += [i1]
105        lst2 += [i2]
106    return lst1, lst2
def list_map(f: Callable, *xs: Iterable) -> List[Any]:
109def list_map(f: Callable, *xs: Iterable) -> List[Any]:
110    return list(map(f, *xs))
def list_zip(*args: List[Any]) -> List[Any]:
113def list_zip(*args: List[Any]) -> List[Any]:
114    fst, *rest = args = list_map(list, args)
115    n = len(fst)
116    for arg in rest:
117        assert len(arg) == n
118    return list(zip(*args))
def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
121def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:
122    assert not len(lst) % 2
123    return split_list(lst, len(lst) // 2)
def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
126def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:
127    l1, l2 = iter(l1), iter(l2)
128    out = [next(l2) if b else next(l1) for b in which]
129    assert next(l1, None) is next(l2, None) is None
130    return out
def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
138def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:
139    assert 0 <= n <= len(lst)
140    return lst[:n], lst[n:]
def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
143def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:
144    assert len(bs) == len(l)
145    lists = lst1, lst2 = [], []
146    for b, x in zip(bs, l):
147        lists[b].append(x)
148    return lst1, lst2
def lru_cache_verbose( maxsize: int = 100, typed: bool = False, tb_start: int = -12, tb_end: int = -7):
151def lru_cache_verbose(
152    maxsize: int = 100,
153    typed: bool = False,
154    tb_start: int = -12,
155    tb_end: int = -7,
156):
157    def decorator(fn: Callable):
158        @lru_cache(maxsize=maxsize, typed=typed)
159        def wrapper(*args, **kwargs) -> Callable:
160            return fn(*args, **kwargs)
161
162        def decorated_function(*args, **kwargs) -> Any:
163            result = wrapper(*args, **kwargs)
164            cache_info = wrapper.cache_info()
165
166            dblog(
167                f"{fn.__name__}.{cache_info} {args.__hash__()}",
168                enable=backend.LOG_LRU,
169            )
170            tb = "".join(traceback.format_list(traceback.extract_stack())[tb_start:tb_end]).replace("\n    ", ":\t") + "-" * 20 + "\n"
171            dblog(f"{tb}", enable=backend.LOG_LRU)
172
173            return result
174
175        decorated_function.cache_info = wrapper.cache_info
176        decorated_function.fn = fn
177        return decorated_function
178
179    return decorator
def cuda_is_available():
182def cuda_is_available():
183    try:
184        import subprocess
185        import platform
186
187        cmd = f"nvidia-smi{'.exe' if platform.system == 'Windows' else ''}"
188        result = subprocess.run([cmd], stdout=subprocess.PIPE)
189        output = result.stdout.decode("utf-8")
190        return True if "NVIDIA-SMI" in output else False
191    except FileNotFoundError:
192        return False
class Hashed:
195class Hashed:
196    val: Any
197
198    def __init__(self, val):
199        self.val = val
200
201    def __hash__(self) -> int:
202        return hash((self.val,))
203
204    def __eq__(self, other):
205        if isinstance(other, Hashed):
206            if isinstance(self.val, Tensor) and isinstance(other.val, Tensor):
207                # because Tensor.__eq__ already for Tensor.equal
208                return id(self.val) == id(other.val)
209            return self.val == other.val
210        return False
211
212    def __repr__(self):
213        return f"Hashed: {repr(self.val)}"
Hashed(val)
198    def __init__(self, val):
199        self.val = val
val: Any
class DType(typing.NamedTuple):
221class DType(NamedTuple):
222    priority: int
223    itemsize: int
224    name: str
225    mlir: str
226    numpy: type
227
228    @property
229    def format_code(self):
230        return f"slope.{self.name}"
231
232    def __repr__(self):
233        return f"<DType: {self.name}>"

class dtypes:
236class dtypes:
237    float32: Final[DType] = DType(4, 4, "float32", "f32", np.float32)
238    uint8: Final[DType] = DType(0, 1, "uint8", "u8", np.uint8)
239    int8: Final[DType] = DType(0, 1, "int8", "i8", np.int8)
240    bool: Final[DType] = DType(0, 1, "bool", "i1", bool)
241    int32: Final[DType] = DType(1, 4, "int32", "i32", np.int32)
242    int64: Final[DType] = DType(2, 8, "int64", "i64", np.int64)
243    uint64: Final[DType] = DType(2, 8, "uint64", "ui64", np.uint64)
244    float16: Final[DType] = DType(0, 2, "float16", "f16", np.float16)
245    half = float16
246    # bfloat16: Final[DType] = DType(0, 2, "bfloat16", "bf16", np.float16)
247
248    all_dtypes = (bool, float16, float32, int8, int32, int64, uint8, uint64)
249    name_dtype_map = {k.name: k for k in all_dtypes}
250    name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
251    mlir_dtype_map = {k.mlir: k for k in all_dtypes}
252    mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}
253
254    @classmethod
255    def is_int(cls, dtype):
256        return dtype in (cls.uint8, cls.int8, cls.int32, cls.uint64, cls.int64)
257
258    @classmethod
259    def is_float(cls, dtype):
260        return dtype in (cls.float16, cls.float32)  # bfloat16 is commented out above, so it is omitted here
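
A quick sketch of the lookup tables and predicates defined above (assumes slope.core is importable):

import numpy as np
from slope.core import dtypes

dt = dtypes.name_dtype_map["float32"]            # <DType: float32>
assert dtypes.mlir_dtype_map["f32"] is dt
assert dtypes.is_float(dt) and not dtypes.is_int(dt)
assert dt.itemsize == 4 and dt.numpy is np.float32
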
class Device(typing.NamedTuple):
263class Device(NamedTuple):
264    name: str
265    idx: int
266
267    @property
268    def format_code(self):
269        return f"'{self.name}:{self.idx}'"
270
271    def __repr__(self):
272        return f"<Device: {self.format_code}>"

class devices:
275class devices:
276    cpu: Final[Device] = Device("cpu", 0)
277    metal: Final[Device] = Device("metal", 0)
278    cuda0: Final[Device] = Device("cuda", 0)
279    # TODO: programmatically define this class attrs to support other setup
280    cuda = cuda0
281    all_devices = (cpu, metal, cuda0)
282    name_idx_device_map = {f"{k.name}:{k.idx}": k for k in all_devices}
283    name_idx_device_map_inv = {v: k for k, v in name_idx_device_map.items()}
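
A usage sketch mirroring the dtypes helpers (assumes slope.core is importable):

from slope.core import devices

dev = devices.name_idx_device_map["cuda:0"]
assert dev is devices.cuda0 is devices.cuda
print(dev.format_code)          # 'cuda:0'
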
class TensorBuffer:
286class TensorBuffer:
287    def __init__(self, val):
288        self.val = val
class Tensor:
291class Tensor:
292    def __init__(self, val: TensorBuffer):
293        assert isinstance(val, TensorBuffer)
294        self.buf = val
295
296    @property
297    def symval(self):
298        return SymbolicTensor.like(self)
299
300    @property
301    def default_dtype(self):
302        return backend.default_dtype
303
304    def is_int(self) -> bool:
305        return self.dtype in (
306            dtypes.int8,
307            dtypes.uint8,
308            dtypes.uint64,
309            dtypes.int32,
310            dtypes.int64,
311        )
312
313    def is_float(self) -> bool:
314        return self.dtype in (dtypes.float16, dtypes.float32)
315
316    def is_unsigned(self) -> bool:
317        return self.dtype is dtypes.uint8
318
319    def to_bool(self):
320        return self.cast(dtypes.bool)
321
322    def short(self):
323        return self.cast(dtypes.int8)
324
325    def int(self):
326        return self.cast(dtypes.int32)
327
328    def long(self):
329        return self.cast(dtypes.int64)
330
331    def half(self):
332        return self.cast(dtypes.float16)
333
334    def float(self):
335        return self.cast(dtypes.float32)
336
337    def __getattr__(self, attr):
338        if attr in vars(backend.operator_set).keys():
339            op = getattr(backend.operator_set, attr)
340            return partial(op, self)
341        elif attr in vars(backend.procedure_set).keys():
342            procedure = getattr(backend.procedure_set, attr)
343            assert not isinstance(procedure, classmethod), f"use {attr} instead of self.{attr}"
344            return partial(procedure, self)
345        else:
346            return self.__getattribute__(attr)
347
348    def __getitem__(self, idx):
349        return self.getitem(idx)
350
351    def __setitem__(self, idx, item):
352        raise NotImplementedError
353
354    def str_short(self):
355        return f"<Tensor: shape={self.shape}, dtype={self.dtype}>"
356
357    __neg__ = lambda self: self.neg()
358    __add__ = lambda self, other: self.add(other)
359    __radd__ = lambda self, other: self.add(other)
360    __sub__ = lambda self, other: self.sub(other)
361    __rsub__ = lambda self, other: self.sub.func(other, self)
362    __mul__ = lambda self, other: self.mul(other)
363    __rmul__ = lambda self, other: self.mul(other)
364    __div__ = lambda self, other: self.div(other)
365    __rdiv__ = lambda self, other: self.div.func(other, self)
366    __truediv__ = __div__
367    __rtruediv__ = __rdiv__  # Python 3 looks up __rtruediv__ for reflected true division
368    __pow__ = lambda self, other: self.pow(other)
369    __rpow__ = lambda self, other: self.pow.func(other, self)
370    __matmul__ = lambda self, other: self.matmul(other)
371    __rmatmul__ = lambda self, other: self.matmul.func(other, self)
372    __invert__ = lambda self: self.invert()
373    __eq__ = lambda self, other: self.equal(other)
374    __ne__ = lambda self, other: self.not_equal(other)
375    __ge__ = lambda self, other: self.greater_equal(other)
376    __le__ = lambda self, other: self.less_equal(other)
377    __gt__ = lambda self, other: self.greater(other)
378    __lt__ = lambda self, other: self.less(other)
379
380    def __hash__(self):
381        return id(self.val)
382
383    val = property(lambda self: self.buf.val)
384
385    def size(self, i):
386        return self.shape[i]
387
388    @property
389    def dtype(self):
390        return backend.dtype_of(self)
391
392    @property
393    def device(self):
394        return backend.device_of(self)
395
396    def numpy(self, memmap=False):
397        return backend.numpy_of(self, memmap)
398
399    @property
400    def shape(self):
401        return backend.shape_of(self)
402
403    @property
404    def ndim(self):
405        return len(self.shape)
406
407    def numel(self):
408        return math.prod(self.shape)
409
410    def element_size(self):
411        return self.dtype.itemsize
412
413    def nbytes(self):
414        return self.numel() * self.element_size()
415
416    def __repr__(self):
417        return f"<Tensor: val=\n{self.numpy()}\nshape={self.shape}, dtype={self.dtype.name}, device={self.device.format_code}>"
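
A hedged sketch of how Tensor methods resolve: most names fall through __getattr__ to the active backend's operator_set/procedure_set. The snippet assumes an initialized backend is reachable as slope.core.backend; the exact bootstrap is outside this listing.

import slope.core as core

x = core.backend.tensor([[1.0, 2.0], [3.0, 4.0]])   # assumes a concrete backend implements tensor()
y = (x + 1).half()                                   # __add__ -> operator_set.add, half() -> cast(float16)
print(y.str_short(), y.nbytes())                     # roughly: <Tensor: shape=(2, 2), dtype=<DType: float16>> 8
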
class SymbolicTensor(Tensor):
420class SymbolicTensor(Tensor):
421    def __init__(self, shape, dtype, device):
422        assert isinstance(dtype, DType)
423        self._shape = tuple(int(i) for i in shape)
424        self._dtype = dtype
425        self._device = device
426
427    @property
428    def symval(self):
429        return self
430
431    @property
432    def val(self):
433        raise RuntimeError(f"SymbolicTensor has no concrete val; requested under {trace_stack[-1]=}")
434
435    @property
436    def shape(self):
437        return self._shape
438
439    @property
440    def dtype(self):
441        return self._dtype
442
443    @property
444    def device(self):
445        return self._device
446
447    def like(self, **overrides):
448        shape = overrides.get("shape", self.shape)
449        dtype = overrides.get("dtype", self.dtype)
450        device = overrides.get("device", self.device)
451        return SymbolicTensor(shape, dtype, device)
452
453    def str_short(self):
454        return f'{str(self.dtype)}[{",".join(str(d) for d in self.shape)}]'
455
456    def __hash__(self):
457        return hash((self.shape, self.dtype))
458
459    def __eq__(self, other):
460        if type(self) != type(other):
461            return False
462        return (self.shape == other.shape) and (self.dtype == other.dtype)
463
464    def __repr__(self):
465        return f"<SymbolicTensor: shape={self.shape}, dtype={self.dtype.name}, device={self.device}>"
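
SymbolicTensor carries only shape/dtype/device, so it can be constructed directly for shape inference; a small sketch (assumes slope.core is importable):

from slope.core import SymbolicTensor, dtypes, devices

sym = SymbolicTensor((2, 3), dtypes.float32, devices.cpu)
flat = sym.like(shape=(6,))                       # same dtype/device, new shape
print(sym.str_short(), "->", flat.str_short())    # <DType: float32>[2,3] -> <DType: float32>[6]
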
class Operator:
473class Operator:
474    def __init__(self, name, variadic_inputs=False, nary_outputs=False):
475        self.name = name
476        self.variadic_inputs = variadic_inputs
477        self.nary_outputs = nary_outputs
478        if self.variadic_inputs:
479            self.reorg_args = self.reorg_args_nary
480
481    def __hash__(self):
482        return hash(self.name)
483
484    def __eq__(self, other):
485        if not isinstance(other, Operator):
486            return False
487        return self.name == other.name
488
489    def args_fixer(self, *args, **params):
490        return args, params
491
492    def __call__(self, *args, **params):
493        args, params = self.reorg_args(args, params)
494        args, params = self.args_fixer(*args, **params)
495        ret = bind(self, *args, **params)
496        if not self.nary_outputs:
497            ret = ret[0]
498        return ret
499
500    def __repr__(self) -> str:
501        return f"<{self.name}>"
502
503    def typecheck(self, *args, **params):
504        raise NotImplementedError
505
506    def jvp(self, *args, **params):
507        raise NotImplementedError
508
509    def T(self, *args, **params):
510        raise NotImplementedError
511
512    def vmap(self, *args, **params):
513        raise NotImplementedError
514
515    def reorg_args(self, args, params):
516        sig = inspect.signature(self.typecheck)
517        args_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and k != "self"]
518        params_strs = [k for k, v in sig.parameters.items() if v.kind == inspect.Parameter.KEYWORD_ONLY and k != "self"]
519
520        if args:
521            if len(args) > len(args_strs):
522                args, rest = args[: len(args_strs)], args[len(args_strs) :]
523                if params_strs:
524                    new_params = {k: rest_arg for k, rest_arg in zip(params_strs, rest) if k not in params}
525                    params = {**new_params, **params}
526            else:
527                args = tuple([params[k] if k in params else arg for k, arg in zip(args_strs, args)])
528                assert len(args) == len(args_strs)
529        return args, params
530
531    def reorg_args_nary(self, args, params):
532        return args, params
533
534    def partial_run(self, trace, tracers, **params):
535        tracers_in = [trace.instantiate_const(t) for t in tracers]
536        symvals_in = [t.symval for t in tracers_in]
537        symvals_out = self.typecheck(*symvals_in, **params)
538        tracers_out = [PartialRunTraceTensor(trace, make_unknown_pval(symval), None) for symval in symvals_out]
539        instruction = InstructionDraft(
540            self,
541            tracers_in,
542            params,
543            symvals_out,
544            list_map(weakref.ref, tracers_out),
545        )
546        for t in tracers_out:
547            t.draft = instruction
548        return tracers_out
549
550    def partial_run_instruction(self, unks_in, instruction):
551        if any(unks_in):
552            instruction1 = None
553            instruction2 = Instruction(
554                instruction.op,
555                instruction.inputs,
556                instruction.params,
557                instruction.out_binders,
558            )
559            unks_out = [True for i in instruction.out_binders]
560            res = [v for unk, v in zip(unks_in, instruction.inputs) if ((not unk) and type(v) is Var)]
561        else:
562            instruction1 = instruction
563            instruction2 = None
564            unks_out = [False for i in instruction.out_binders]
565            res = None
566
567        return instruction1, instruction2, unks_out, res
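
A hedged sketch of what reorg_args does: positional arguments beyond typecheck's positional parameters are rebound to its keyword-only parameters, so callers may pass params positionally. PadOp below is a made-up operator for illustration only, not part of the listing.

from slope.core import Operator, SymbolicTensor

class PadOp(Operator):                                   # hypothetical operator
    def typecheck(self, x, *, padding, value):
        return [SymbolicTensor.like(x)]

op = PadOp("pad")
args, params = op.reorg_args(("x_tensor", (1, 1), 0.0), {})
print(args, params)                                      # ('x_tensor',) {'padding': (1, 1), 'value': 0.0}
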
class MetaOperator(Operator):
570class MetaOperator(Operator):
571    def meta_impl(self, *args, **kwargs):
572        raise NotImplementedError
class UnaryOperator(Operator):
575class UnaryOperator(Operator):
576    def vmap(self, x, *, dim_size, vals_in, dims_in, **params):
577        (x,), (x_bdim,) = vals_in, dims_in
578        return [self(x, **params)], [x_bdim]
579
580    def typecheck(self, x, **params):
581        return [SymbolicTensor.like(x)]
582
583    def jvp(self, primals, tangents, **params):
584        (x,), (x_dot,) = primals, tangents
585        return [self(x, **params)], [self(x_dot, **params)]
class BinaryOperator(Operator):
588class BinaryOperator(Operator):
589    boolean_output = False
590
591    def args_fixer(self, x, w, **params):
592        if isinstance(x, UndefinedPrimal) or type(w) is UndefinedPrimal:
593            assert x.shape == w.shape
594            return (x, w), params
595
596        if type(x) in TraceTensor.PYTHON_TYPES:
597            x = backend.full(shape=(), fill_value=x, dtype=w.dtype)
598        elif type(w) in TraceTensor.PYTHON_TYPES:
599            w = backend.full(shape=(), fill_value=w, dtype=x.dtype)
600
601        shape_delta = x.ndim - w.ndim
602        if shape_delta > 0:
603            w = w.reshape((1,) * shape_delta + w.shape)
604        elif shape_delta < 0:
605            x = x.reshape((1,) * -shape_delta + x.shape)
606
607        shape_ret = tuple([max(x, w) for x, w in zip(x.shape, w.shape)])
608        if x.shape != shape_ret:
609            x = x.expand(shape_ret)
610        if w.shape != shape_ret:
611            w = w.expand(shape_ret)
612
613        if type(x) is Tensor and isinstance(w, TraceTensor):
614            x = w._trace.pure(x)
615        elif type(w) is Tensor and isinstance(x, TraceTensor):
616            w = x._trace.pure(w)
617        # TODO: https://jax.readthedocs.io/en/latest/type_promotion.html
618        if x.dtype != w.dtype:
619            # {int, bool} -> float
620            if dtypes.is_float(x.dtype) ^ dtypes.is_float(w.dtype):
621                if dtypes.is_float(w.dtype):
622                    x = x.cast(w.dtype)
623                elif dtypes.is_float(x.dtype):
624                    w = w.cast(x.dtype)
625            # bool -> int
626            elif dtypes.is_int(x.dtype) ^ dtypes.is_int(w.dtype):
627                if dtypes.is_int(w.dtype):
628                    x = x.cast(w.dtype)
629                elif dtypes.is_int(x.dtype):
630                    w = w.cast(x.dtype)
631            else:  # TODO: fine-grained type promotions
632                raise NotImplementedError("No other type promotion rules")
633
634        return (x, w), params
635
636    def vmap(self, dim_size, vals_in, dims_in, **params):
637        (x, w), (x_bdim, w_bdim) = vals_in, dims_in
638        if x_bdim != w_bdim:
639            if x_bdim is None:
640                x = VMapTrace.move_vmap_dim(x, dim_size, x_bdim, w_bdim)
641                x_bdim = w_bdim
642            else:
643                w = VMapTrace.move_vmap_dim(w, dim_size, w_bdim, x_bdim)
644        return [self(x, w, **params)], [x_bdim]
645
646    def typecheck(self, x: SymbolicTensor, y: SymbolicTensor, **params) -> List[SymbolicTensor]:
647        if not isinstance(x, (Tensor, SymbolicTensor)) or not isinstance(y, (Tensor, SymbolicTensor)):
648            raise TypeError
649        symx = SymbolicTensor.like(x, dtype=dtypes.bool if self.boolean_output else x.dtype)
650        symy = SymbolicTensor.like(y, dtype=dtypes.bool if self.boolean_output else y.dtype)
651        if x.dtype != y.dtype:
652            raise TypeError(f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})")
653        if symx == symy:
654            return [symx]
655        shape_delta = len(symx.shape) - len(symy.shape)
656        if shape_delta > 0:
657            symy = symy.like(shape=(1,) * shape_delta + symy.shape)
658        elif shape_delta < 0:
659            symx = symx.like(shape=(1,) * -shape_delta + symx.shape)
660        if symx == symy:
661            return [symx]
662        else:
663            shape_ret = tuple([max(x, w) for x, w in zip(symx.shape, symy.shape)])
664            if symx.shape != shape_ret:
665                symx = symx.like(shape=shape_ret)
666            if symy.shape != shape_ret:
667                symy = symy.like(shape=shape_ret)
668            if symx != symy:
669                raise TypeError(f"symx ({symx}) != symy ({symy})")
670            return [symx]
671
672    def jvp(self, primals, tangents, **params):
673        (x, w), (x_dot, w_dot) = primals, tangents
674        return [self(x, w, **params)], [self(x_dot, w_dot, **params)]
675
676    def T(self, cotangents, x, w):
677        (gL_y,) = cotangents
678        if self.boolean_output:
679            gL_y = gL_y.cast(x.dtype)
680        if isinstance(x, UndefinedPrimal):
681            return [gL_y, NullCotangent]
682        elif isinstance(w, UndefinedPrimal):
683            return [NullCotangent, gL_y]
684        else:
685            raise ValueError
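
The shape handling in args_fixer/typecheck above follows a simple rule: left-pad the lower-rank shape with 1s, then take the elementwise max. A standalone illustration (the max rule mirrors the listing; checking that mismatched sizes really are broadcastable is left to the subsequent expand):

def broadcast_shape(xs, ws):
    delta = len(xs) - len(ws)
    if delta > 0:
        ws = (1,) * delta + tuple(ws)
    elif delta < 0:
        xs = (1,) * -delta + tuple(xs)
    return tuple(max(a, b) for a, b in zip(xs, ws))

assert broadcast_shape((4, 1, 3), (5, 3)) == (4, 5, 3)
assert broadcast_shape((), (2, 2)) == (2, 2)             # scalars broadcast against anything
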
class ReduceOperator(Operator):
688class ReduceOperator(Operator):
689    def args_fixer(self, x, *, dim=None, keepdim=False):
690        if dim is None:
691            dim = tuple(range(x.ndim))
692        elif isinstance(dim, int):
693            dim = (dim,)
694        dim = tuple(a if a >= 0 else a + len(x.shape) for a in dim)
695        return (x,), dict(dim=dim, keepdim=keepdim)
696
697    def vmap(self, dim_size, vals_in, dims_in, *, dim, keepdim):
698        (x,), (x_bdim,) = vals_in, dims_in
699        dim = tuple(a + (x_bdim <= a) for a in dim)
700        out_bdim = x_bdim - sum(a < x_bdim for a in dim)
701        return [self(x, dim=dim, keepdim=keepdim)], [out_bdim]
702
703    def typecheck(self, x: SymbolicTensor, *, dim=None, keepdim=False) -> List[SymbolicTensor]:
704        dim = list(set([a + len(x.shape) if a < 0 else a for a in dim]))
705        if keepdim:
706            new_shape = [d if i not in dim else 1 for i, d in enumerate(x.shape)]
707        else:
708            new_shape = [d for i, d in enumerate(x.shape) if i not in dim]
709        return [SymbolicTensor.like(x, shape=tuple(new_shape))]
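
A standalone illustration (not the module's own function) of the output-shape rule in ReduceOperator.typecheck: negative dims are normalized, then reduced dims are either kept as size 1 or dropped.

def reduce_shape(shape, dim, keepdim=False):
    dim = {d % len(shape) for d in dim}          # normalize negative dims
    if keepdim:
        return tuple(1 if i in dim else d for i, d in enumerate(shape))
    return tuple(d for i, d in enumerate(shape) if i not in dim)

assert reduce_shape((2, 3, 4), dim=(-1,)) == (2, 3)
assert reduce_shape((2, 3, 4), dim=(0, 2), keepdim=True) == (1, 3, 1)
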
class InitOperator(Operator):
712class InitOperator(Operator):
713    def vmap(self, dim_size, vals_in, dims_in, **params):
714        (x_bdim,) = dims_in
715        y = self(**params)
716        y = y.unsqueeze(x_bdim)
717        return [y], [x_bdim]
718
719    def jvp(self, primals, tangents, **params):
720        y = self(**params)
721        y_dot = NullCotangent(y.symval)
722        return [y], [y_dot]
723
724    def T(self, cotangents, **params):
725        return [NullCotangent(cotangents[0])]
class ShapeOperator(Operator):
728class ShapeOperator(Operator):
729    pass
class GeneralReduceOperator(Operator):
732class GeneralReduceOperator(Operator):
733    pass
class OperatorSet:
736class OperatorSet:
737    def __init__(self):
738        self.register("jit_op")(JitOp)
739
740    def register(self, name, variadic_inputs=False, nary_outputs=False, aliases=()):
741        def wrap(op_cls):
742            assert name not in vars(self)
743            op = op_cls(name, variadic_inputs, nary_outputs)
744            setattr(self, name, op)
745            for a in aliases:
746                setattr(self, a, op)
747            return op_cls
748
749        return wrap
class ProcedureSet:
752class ProcedureSet:
753    def register(self, aliases=()):
754        def wrap(f):
755            assert f.__name__ not in vars(self)
756            setattr(self, f.__name__, f)
757            for a in aliases:
758                setattr(self, a, f)
759            return f
760
761        return wrap
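
A hedged sketch of how a backend assembles its sets; Exp and square below are illustrative stand-ins (OperatorSet() also auto-registers the module's jit_op on construction):

from slope.core import OperatorSet, ProcedureSet, UnaryOperator

operator_set = OperatorSet()

@operator_set.register("exp")
class Exp(UnaryOperator):
    pass                                    # typecheck/vmap/jvp inherited from UnaryOperator

procedure_set = ProcedureSet()

@procedure_set.register(aliases=("squared",))
def square(x):
    return x * x                            # procedures compose already-registered operators

print(operator_set.exp)                     # <exp>
assert procedure_set.squared is procedure_set.square
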
class CodegenOutput(typing.NamedTuple):
764class CodegenOutput(NamedTuple):
765    code_lines: List[str]
766    fn_defs: Dict[str, List[str]]
767    in_binders: List["ProgramEnvVar"]
768    outs: List["ProgramEnvVar"]

class Backend:
771class Backend:
772    LOG_LRU = int(os.environ.get("LOG_LRU", 0))
773    LOG_JIT = int(os.environ.get("LOG_JIT", 0))
774    LOG_TREE = int(os.environ.get("LOG_TREE", 0))
775    LOG_BACKEND = int(os.environ.get("LOG_BACKEND", 0))
776    LOG_PROGRAM = int(os.environ.get("LOG_PROGRAM", 0))
777    LOG_INIT = int(os.environ.get("LOG_INIT", 1))
778    device_var = os.environ.get("DEFAULT_DEVICE", "cpu:0")
779    if device_var[-2] != ":":
780        device_var += ":0"
781    DEFAULT_DEVICE = devices.name_idx_device_map[device_var]
782    DEFAULT_DTYPE = dtypes.name_dtype_map[os.environ.get("DEFAULT_DTYPE", "float32")]
783    dtype_for_indices: DType = None  # need to override
784
785    def __init__(
786        self,
787        operator_set: OperatorSet,
788        procedure_set: ProcedureSet,
789    ):
790        self.operator_set = operator_set
791        self.procedure_set = procedure_set
792        self.node_types = dict()
793        self.impls = dict()
794        self.register_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs), "tuple")
795        self.register_node(list, lambda l: (None, l), lambda _, xs: list(xs), "list")
796        self.register_node(
797            dict,
798            lambda d: list_map(tuple, unzip2(sorted(d.items()))),
799            lambda keys, vals: dict(list_zip(keys, vals)),
800            "dict",
801        )
802        self.register_node(
803            UndefinedPrimal,
804            lambda u: (u.symval, ()),
805            lambda symval, _: UndefinedPrimal(symval),
806            "UndefinedPrimal",
807        )
808
809    def set_impl(self, op: Union[types.LambdaType, types.FunctionType]):
810        def set_impl_(fn):
811            self.impls[op] = types.MethodType(fn, self)
812
813        return set_impl_
814
815    def register_node(self, ty: Type, to_iter: Callable, from_iter: Callable, name=None) -> None:
816        if name is None:
817            name = str(ty)
818        self.node_types[ty] = NodeType(name, to_iter, from_iter)
819
820    def __getattr__(self, attr):
821        try:
822            dblog(
823                f"Looking {self}.{attr} in operator_set",
824                enable=backend.LOG_BACKEND,
825            )
826            return getattr(self.operator_set, attr)
827        except AttributeError:
828            pass
829        try:
830            dblog(
831                f"Looking {self}.{attr} in procedure_set",
832                enable=backend.LOG_BACKEND,
833            )
834            return getattr(self.procedure_set, attr)
835        except AttributeError:
836            pass
837        dblog(
838            f"Fallback to default {self} getattribute",
839            enable=backend.LOG_BACKEND,
840        )
841        return super().__getattribute__(attr)
842
843    def tensor(
844        self,
845        val: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
846        dtype: Optional[Any] = None,
847        device=None,
848    ):
849        if isinstance(val, TensorBuffer):
850            return Tensor(val)
851        elif isinstance(val, Tensor):
852            return val
853        if type(val) is bytes:
854            val = np.frombuffer(val, dtype=dtype)
855        return self.from_numpy(val, dtype, device)
856
857    def symbolic_tensor(
858        self,
859        shape: Union[list, tuple, np.ndarray, "TensorBuffer"] = None,
860        dtype: Optional[Any] = None,
861        device=None,
862    ):
863        dtype = dtype or self.DEFAULT_DTYPE
864        device = device or self.DEFAULT_DEVICE
865        return SymbolicTensor(shape, dtype, device)
866
867    def seed(self, seed):
868        raise NotImplementedError
869
870    @property
871    def default_dtype_value(self):
872        return self.dtype_map[backend.DEFAULT_DTYPE]
873
874    def set_method(self, method):
875        setattr(self, method.__name__, types.MethodType(method, self))
876
877    def from_numpy(self, val, dtype=None, device=None):  # tensor() calls this as from_numpy(val, dtype, device)
878        raise NotImplementedError
879
880    def numpy_of(self, tensor, memmap=False):  # Tensor.numpy() forwards a memmap flag
881        raise NotImplementedError
882
883    def device_of(self, tensor):
884        raise NotImplementedError
885
886    def shape_of(self, tensor):
887        raise NotImplementedError
888
889    def dtype_of(self, tensor):
890        raise NotImplementedError
891
892    @lru_cache_verbose()
893    def jit_program(
894        self,
895        hashed_program: Hashed,
896        hashed_consts: Tuple[Hashed, ...],
897    ):
898        program: Program = hashed_program.val
899        typecheck_program(program)
900        consts = [x.val for x in hashed_consts]
901        in_symvals = [v.symval for v in program.in_binders[len(consts) :]]
902        codegen_output: CodegenOutput = self.codegen(program, consts + in_symvals, fn_name="main")
903        fn, code = self.compile(codegen_output)
904        jit_output = JitOutput(program, codegen_output, fn, code, consts)
905        return jit_output
906
907    def codegen(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str):
908        "Returns compiler IR from the Program"
909        raise NotImplementedError
910
911    def compile(self, program: "Program", args: Tuple, in_symvals: Tuple, name: str):
912        "Compiles compiler IR to a Python callable function"
913        raise NotImplementedError
914
915    def export(self, jit_output, *args, **params):
916        raise NotImplementedError
917
918    def load(self, path, single_key="_tensor"):
919        with open(path, mode="rb") as f:
920            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as m:
921                json_len = int(np.frombuffer(m[:8], dtype=np.int64)[0])  # save() writes an 8-byte length header
922                start = 8 + json_len
923                metadata = json.loads(m[8:start])
924                ret = {}
925                for k, v in metadata.items():
926                    if k != "__metadata__":
927                        dtype = dtypes.mlir_dtype_map[v["dtype"]]  # the mlir-name table lives on dtypes
928                        data_start = start + v["data_offsets"][0]
929                        data_end = start + v["data_offsets"][1]
930                        t_np = np.frombuffer(m[data_start:data_end], dtype=dtype.numpy)  # numpy field is the scalar type itself
931                        t = backend.tensor(t_np, dtype=dtype)
932                        t = t.reshape(tuple(v["shape"]))
933                        ret[k] = t
934                if len(ret) == 1 and single_key in ret.keys():
935                    return ret[single_key]
936                return ret
937
938    def save(self, tensors: Dict[str, Tensor], path: str, single_key="_tensor"):
939        if isinstance(tensors, Tensor):
940            tensors = {single_key: tensors}
941        else:
942            assert all((isinstance(k, str) and isinstance(v, Tensor)) for k, v in tensors.items())
943
944        metadata, offset = {}, 0
945        for k, v in tensors.items():
946            metadata[k] = {
947                "dtype": v.dtype.mlir,
948                "shape": list(v.shape),
949                "data_offsets": [offset, offset + v.nbytes()],
950            }
951            offset += v.nbytes()
952        j = json.dumps(metadata, separators=(",", ":"))
953        Path(path).unlink(missing_ok=True)
954        jbytes = j.encode("utf-8")
955        start = 8 + len(jbytes)
956        with open(path, mode="wb") as f:  # make empty file, fill with enough space
957            f.write(b"\x00" * (start + offset))
958        with open(path, mode="r+b") as f:
959            with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_WRITE) as m:
960                m[0:8] = np.int64(len(j)).tobytes()
961                m[8:start] = jbytes
962                for t, tm in zip(tensors.values(), metadata.values()):
963                    data_start, data_end = tm["data_offsets"]
964                    m[start + data_start : start + data_end] = t.numpy().tobytes()
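
A hedged, numpy-only sketch of the container format save() writes and load() reads back: an 8-byte int64 JSON length, the JSON metadata, then each tensor's raw bytes at its data_offsets (relative to the end of the JSON block). The blob/meta names here are illustrative, not part of the module.

import json
import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
meta = {"_tensor": {"dtype": "f32", "shape": [2, 3], "data_offsets": [0, arr.nbytes]}}
jbytes = json.dumps(meta, separators=(",", ":")).encode("utf-8")
blob = np.int64(len(jbytes)).tobytes() + jbytes + arr.tobytes()

# reading it back the same way load() does
json_len = int(np.frombuffer(blob[:8], dtype=np.int64)[0])
meta2 = json.loads(blob[8 : 8 + json_len])
data = np.frombuffer(blob[8 + json_len :], dtype=np.float32).reshape(meta2["_tensor"]["shape"])
assert (data == arr).all()
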
class Var:
972class Var:
973    def __init__(self, symval):
974        self.symval = symval
975        self.val = None
class Lit:
978class Lit:
979    def __init__(self, val):
980        self.symval = SymbolicTensor.like(get_symval(val))
981        self.val = val
Atom = typing.Union[Var, Lit]
class Instruction(typing.NamedTuple):
987class Instruction(NamedTuple):
988    op: Operator
989    inputs: List[Atom]
990    params: Dict[str, Any]
991    out_binders: List[Atom]

class ProgramEnvVar(typing.NamedTuple):
 994class ProgramEnvVar(NamedTuple):
 995    name: str
 996    symval: SymbolicTensor
 997    is_const: bool = False
 998
 999    @property
1000    def shape(self):
1001        return self.symval.shape
1002
1003    @property
1004    def dtype(self):
1005        return self.symval.dtype
1006
1007    @property
1008    def device(self):
1009        return self.symval.device
1010
1011    @property
1012    def ndim(self):
1013        return self.symval.ndim
1014
1015    def numpy(self):
1016        return self.symval.numpy()
1017
1018    def __repr__(self):
1019        return f"<ProgramEnvVar: name={self.name}, symval={self.symval}>"
1020
1021    str_short = __repr__

class Program:
1024class Program:
1025    def __init__(
1026        self,
1027        in_binders: Any,
1028        instructions: Tuple[Instruction],
1029        outs: Any,
1030        num_consts: int = 0,
1031        static_args: Any = (),
1032        name: str = "my_program",
1033        indent_amount=4,
1034    ):
1035        self.in_binders: Any = in_binders
1036        self.outs: Any = outs
1037        self.instructions = self.prune_instructions(instructions, outs)
1038        self.num_consts: int = num_consts
1039        self.static_args = static_args
1040        self.name: str = name
1041        self.indent_amount: int = indent_amount
1042
1043        self.env: Dict[ProgramEnvVar, Any] = dict()
1044        for inb in self.in_binders:
1045            prefix = "x" if type(inb.symval) is SymbolicTensor else "c"
1046            idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()])
1047            self.env[inb] = ProgramEnvVar(f"{prefix}{idx}", inb.symval, prefix == "c")
1048        for instruction in self.instructions:
1049            if len(instruction.out_binders) == 0:
1050                continue
1051            for outb in instruction.out_binders:
1052                prefix = "y" if outb in self.outs else "z"
1053                idx = sum([1 if v.name[0] == prefix else 0 for v in self.env.values()])
1054                self.env[outb] = ProgramEnvVar(f"{prefix}{idx}", outb.symval)
1055        self.curr_repr = repr(self)
1056
1057    def pprint_shape(self, symval, scalar_as_empty_array=False):
1058        xdtype = symval.dtype.mlir
1059        if len(symval.shape) > 0:
1060            xshape = f"{', '.join((repr(i) for i in symval.shape))}"
1061            return f"[{xshape}, {xdtype}]"
1062        else:
1063            return f"[{xdtype}]"
1064
1065    def pprint_sig(self, in_symvals, out_symvals, unpack_unary_output=False):
1066        in_code = ", ".join(self.pprint_shape(t) for t in in_symvals)
1067        in_code = f"({in_code})" if len(in_symvals) > 1 else in_code
1068        out_code = ", ".join(self.pprint_shape(t) for t in out_symvals)
1069        out_code = f"({out_code})" if len(out_symvals) > 1 or unpack_unary_output else out_code
1070        typing_code = f"{in_code} -> {out_code}"
1071        return typing_code
1072
1073    def __repr__(self):
1074        fn_defs = self.instructions_as_code(self, dict())
1075        return "\n".join(line for code_lines in fn_defs.values() for line in code_lines)
1076
1077    def save(self, *args, dir_path="/tmp/slope_program", dry_run=False):
1078        os.makedirs(dir_path, exist_ok=True)
1079        head_code_lines = [f"import slope # backend={backend.__class__.__name__}"]
1080        fn_defs = self.instructions_as_code(self, dict())
1081        in_binders_vars = [self.env[i] for i in self.in_binders]
1082        for i in range(len(self.in_binders)):
1083            ibv = in_binders_vars[i]
1084            if ibv.is_const:
1085                const_filename = f"{ibv.name}.safetensors"
1086                const_path = os.path.join(dir_path, f"{const_filename}")
1087                if not dry_run:
1088                    backend.save(args[i], const_path)
1089                dblog(
1090                    f"Saved {ibv.name} at {const_path}",
1091                    enable=backend.LOG_BACKEND,
1092                )
1093                head_code_lines += [f"""{ibv.name} = slope.load("./{const_filename}")"""]
1094        head_code_lines += [""]
1095        code = "\n".join(head_code_lines + [line for code_lines in fn_defs.values() for line in code_lines])
1096        dblog(
1097            f"Contents of {self.name}:\n```\n{code}\n```",
1098            enable=backend.LOG_BACKEND,
1099        )
1100        program_path = os.path.join(dir_path, "main.py")
1101        if not dry_run:
1102            with open(program_path, "w") as f:
1103                f.write(code)
1104        dblog(
1105            f"Saved program {self.name} at {program_path}",
1106            enable=backend.LOG_BACKEND,
1107        )
1108        ls_contents = "\n\t".join(os.listdir(dir_path))
1109        dblog(
1110            f"Contents of {dir_path}:\n\t{ls_contents}",
1111            enable=backend.LOG_BACKEND,
1112        )
1113
1114    def __hash__(self):
1115        return hash(self.curr_repr)
1116
1117    def __eq__(self, other):
1118        return self is other
1119
1120    @classmethod
1121    def instructions_as_code(cls, program, fn_defs):
1122        def indent(code, indent_amount):
1123            spaces = " " * (len(code) - len(code.lstrip()))
1124            spaces += " " * indent_amount
1125            return "\n".join([spaces + line for line in code.strip().split("\n")])
1126
1127        in_binders_vars = [program.env[i] for i in program.in_binders]
1128        body_code_lines = []
1129        for instruction in program.instructions:
1130            if len(instruction.out_binders) == 0:
1131                continue
1132            params = instruction.params.copy()
1133            for param_name, param in params.items():
1134                if isinstance(param, Program):
1135                    sub_program = param
1136                    fn_defs = cls.instructions_as_code(sub_program, fn_defs)
1137                    program_in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs)
1138                    params[param_name] = f"slope.make_program({sub_program.name}, {program_in_vals})[0]"
1139                if isinstance(param, DType):
1140                    params[param_name] = f"slope.{param.name}"
1141            param_vals = ", ".join(f"{param_name}={param}" for param_name, param in params.items())
1142            in_vals = ", ".join(f"{program.env[x].name}" for x in instruction.inputs)
1143            out_vals = ", ".join(f"{program.env[z].name}" for z in instruction.out_binders)
1144            sig = program.pprint_sig(
1145                [program.env[x].symval for x in instruction.inputs],
1146                [program.env[y].symval for y in instruction.out_binders],
1147            )
1148            line = f"""{out_vals} = slope.{instruction.op.name}({in_vals}{", " if (param_vals and in_vals) else ""}{param_vals}) # {sig}"""
1149            body_code_lines += [indent(line, program.indent_amount)]
1150
1151        fn_args_str = ", ".join([f"{i.name}" for i in in_binders_vars])
1152        # fn_static_args_str = ", ".join([f"{a}={a_val}" for a, a_val in program.static_args])
1153        out_vars = [program.env[o] for o in program.outs]
1154        fn_sig = program.pprint_sig(
1155            [i.symval for i in in_binders_vars],
1156            [o.symval for o in out_vars],
1157        )
1158        head_code_line = [f"def {program.name}({fn_args_str}): # {fn_sig}"]
1159        out_str = ", ".join([f"{o.name}" for o in out_vars])
1160        tail_code_line = [indent(f"return {out_str}", program.indent_amount)]
1161        code_lines = head_code_line + body_code_lines + tail_code_line + ["\n"]
1162
1163        fn_defs[program.name] = code_lines
1164        return fn_defs
1165
1166    @staticmethod
1167    def prune_instructions(instructions, outs):
1168        graph = dict()
1169        for instruction in instructions:
1170            parent_nodes, child_nodes = instruction.out_binders, instruction.inputs
1171            for parent in parent_nodes:
1172                if parent not in graph:
1173                    graph[parent] = set()
1174                for child in child_nodes:
1175                    graph[parent].add(child)
1176        visited_from_terminal = set()
1177
1178        def dfs(node, visited):
1179            visited.add(node)
1180            if node in graph:
1181                for neighbor in graph[node]:
1182                    if neighbor not in visited:
1183                        dfs(neighbor, visited)
1184
1185        for terminal_node in outs:
1186            dfs(terminal_node, visited_from_terminal)
1187        unreachable_nodes = set(graph.keys()) - visited_from_terminal
1188
1189        instructions_to_prune = []
1190        for instruction in instructions:
1191            parent_nodes, child_nodes = instruction.out_binders, instruction.inputs
1192            if any(node in unreachable_nodes for node in parent_nodes) or any(node in unreachable_nodes for node in child_nodes):
1193                instructions_to_prune += [instruction]
1194        new_instructions = [inst for inst in instructions if inst not in instructions_to_prune]
1195        if backend.LOG_PROGRAM:
1196            LI = len(instructions)
1197            LNI = len(new_instructions)
1198            DIFF = LI - LNI
1199            UN = len(unreachable_nodes)
1200            dblog(f"Before: {LI}\tAfter: {LNI}\tDiff vs Unreachables: {DIFF} == {UN} = {DIFF==UN}")
1201        return new_instructions
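Program.__repr__ and Program.save render the recorded instructions as plain Python that calls back into slope, using the binder naming set up in __init__ (x* for inputs, c* for captured consts, z* for intermediates, y* for outputs) and the shape/dtype signatures from pprint_sig. A hedged sketch of the text instructions_as_code would emit for a hypothetical two-instruction program (the op names, shapes, and the f32 dtype string are illustrative, not taken from a real trace):

def my_program(x0): # [2, 3, f32] -> [2, 3, f32]
    z0 = slope.exp(x0) # [2, 3, f32] -> [2, 3, f32]
    y0 = slope.mul(x0, z0) # ([2, 3, f32], [2, 3, f32]) -> [2, 3, f32]
    return y0
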
class ProgramType(typing.NamedTuple):
1204class ProgramType(NamedTuple):
1205    in_types: Tuple[SymbolicTensor]
1206    out_types: Tuple[SymbolicTensor]
1207
1208    def __repr__(self):
1209        in_types = ", ".join(symval.str_short() for symval in self.in_types)
1210        out_types = ", ".join(symval.str_short() for symval in self.out_types)
1211        return f"({in_types}) -> ({out_types})"

class Empty:
1219class Empty:
1220    pass
empty = <Empty object>
class Store:
1226class Store:
1227    val = empty
1228
1229    def set_value(self, val):
1230        assert self.val is empty
1231        self.val = val
1232
1233    def __call__(self):
1234        return self.val
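Store is a small write-once cell: set_value may be called exactly once (it asserts the slot still holds the empty sentinel), and calling the store returns whatever was stashed. A minimal usage sketch:

store = Store()
store.set_value(42)   # a second set_value would trip the assert
assert store() == 42
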
class NodeType(typing.NamedTuple):
1237class NodeType(NamedTuple):
1238    name: str
1239    flatten: Callable
1240    unflatten: Callable

class TreeDef(typing.NamedTuple):
1243class TreeDef(NamedTuple):
1244    node_type: NodeType
1245    node_metadata: Hashable
1246    child_treedefs: Tuple["TreeDef", ...]
1247
1248    def __repr__(self):
1249        ret = self.tree_repr(self)
1250        return ret
1251
1252    def tree_repr(self, tree, indent="  ", prefix="", last=True):
1253        ret = ""
1254
1255        def _tree_repr(tree, indent, prefix, last):
1256            nonlocal ret
1257            if isinstance(tree, TreeDef):
1258                ret += f'{prefix} {("└─" if last else "├─")} {tree.node_type.name}\n'
1259                for i, item in enumerate(tree.child_treedefs):
1260                    new_prefix = prefix + (indent if not last else "   ")
1261                    new_last = i == len(tree.child_treedefs) - 1
1262                    _tree_repr(item, indent, new_prefix, new_last)
1263            else:
1264                ret += f'{prefix} {("└─" if last else "├─")} {tree}\n'
1265
1266        _tree_repr(tree, indent="  ", prefix="", last=True)
1267        return ret
1268
1269    @property
1270    def num_leaves(self):
1271        def get_num_leaves(x):
1272            if isinstance(x, Leaf):
1273                return 1
1274            else:
1275                return sum(get_num_leaves(sub_x) for sub_x in x.child_treedefs)
1276
1277        return sum(get_num_leaves(x) for x in self.child_treedefs)

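TreeDef records the container structure left behind by tree_flatten (which JitOutput.__call__ below uses on its argument tuple); num_leaves counts Leaf entries and repr draws the nesting via tree_repr. A hedged sketch, assuming the tuple node type is registered elsewhere in slope.core:

leaves, treedef = tree_flatten(((1, 2), (3,)))
print(treedef.num_leaves)   # expected: 3
print(treedef)              # nested node types drawn with ├─ / └─ branches by tree_repr
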
class Leaf:
1280class Leaf:
1281    def __init__(self, val):
1282        if hasattr(val, "shape"):
1283            val = SymbolicTensor.like(val)
1284        self.val = val
1285
1286    def __repr__(self):
1287        ret = self.val.str_short() if isinstance(self.val, SymbolicTensor) else repr(self.val)
1288        return f"<Leaf: {ret}>"
1289
1290    def __hash__(self):
1291        return hash(self.val)
1292
1293    def __eq__(self, other):
1294        return True  # make TreeDef.__eq__ ignore Leaf values
1295        # if isinstance(other, Leaf): # TODO: test above assumption
1296        #     return self.val == other.val
class JitOutput:
1304class JitOutput:
1305    def __init__(self, program: Program, codegen_output: CodegenOutput, fn, code: str, consts: List[Any]):
1306        super().__init__()
1307        self.program = program
1308        self.code = code
1309        self.codegen_output = codegen_output
1310        self.fn: Callable = fn
1311        self.consts = consts
1312
1313    def __call__(self, *args, **params):
1314        args, in_tree = tree_flatten(args)
1315        args = tree_map(lambda a: a.val if isinstance(a, Tensor) else a, args)
1316        try:
1317            outs = self.fn(*args, **params)
1318            if not isinstance(outs, tuple):  # TODO: IREE FunctionInvoker destructure 1-tuple, need to undo
1319                outs = (outs,)
1320        except Exception as e:
1321            dblog(self.code, enable=backend.LOG_JIT)
1322            raise
1323        return [backend.tensor(TensorBuffer(o)) for o in outs]
class JitOp(MetaOperator):
1326class JitOp(MetaOperator):
1327    def meta_impl(self, *args, program: Program, **_):
1328        hashed_program = Hashed(program)
1329        num_consts = program.num_consts
1330        consts, args = args[:num_consts], args[num_consts:]
1331        hashed_consts = tuple(map(Hashed, consts))
1332        jit_output = backend.jit_program(hashed_program, hashed_consts)
1333        ret = jit_output(*consts, *args)
1334        return ret
1335
1336    def reorg_args(self, args, params):
1337        return args, params
1338
1339    def typecheck(self, *in_types, program: Program):
1340        program_type = typecheck_program(program)
1341        if not all(t1 == t2 for t1, t2 in zip(program_type.in_types, in_types)):
1342            ret = "Type mismatch program.in_types vs in_types:\n"
1343            for i, j in zip(program_type.in_types, in_types):
1344                ret += f"{i}, {j}, {i == j}"
1345            raise TypeError(ret)
1346        return program_type.out_types
1347
1348    def vmap(self, dim_size, vals_in, dims_in, program: Program):
1349        program, consts = vmap_program(program, dim_size, tuple(dims_in))
1350        outs = self(*consts, *vals_in, program=program)
1351        if not isinstance(outs, tuple):
1352            outs = (outs,)
1353        return outs, [0] * len(outs)
1354
1355    def jvp(self, primals, tangents, *, program):
1356        new_program, new_consts = jvp_program(program)
1357        outs = bind(
1358            self,
1359            *new_consts,
1360            *primals,
1361            *tangents,
1362            program=new_program,
1363        )
1364        n = len(outs) // 2
1365        primals_out, tangents_out = outs[:n], outs[n:]
1366        return primals_out, tangents_out
1367
1368    def T(self, cotangents, *invals, program):
1369        undef_primals = [isinstance(x, UndefinedPrimal) for x in invals]
1370        transposed_program, new_consts = transpose_program(program, tuple(undef_primals))
1371
1372        residuals, _ = partition_list(undef_primals, invals)
1373        outs = bind(
1374            self,
1375            *new_consts,
1376            *residuals,
1377            *cotangents,
1378            program=transposed_program,
1379        )
1380        outs = iter(outs)
1381
1382        return [next(outs) if undef else None for undef in undef_primals]
1383
1384    def partial_run(self, trace, tracers, *, program):
1385        in_unknowns = [not t.pval.is_known for t in tracers]
1386        program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns)
1387        known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
1388        known_vals = [t.pval.const for t in known_tracers]
1389        outs1_res = bind(backend.jit_op, *known_vals, program=program1)
1390        outs1, res = split_list(outs1_res, len(program1.outs) - num_res)
1391        res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
1392        outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs]
1393        instruction = InstructionDraft(
1394            self,
1395            res_tracers + unknown_tracers,
1396            dict(program=program2),
1397            [v.symval for v in program2.outs],
1398            list_map(weakref.ref, outs2),
1399        )
1400        for t in outs2:
1401            t.draft = instruction
1402
1403        return merge_lists(out_unknowns, outs1, outs2)
1404
1405    def partial_run_instruction(self, unks_in, instruction) -> Tuple[Instruction, Instruction, List[bool], List[Var]]:
1406        program = instruction.params["program"]
1407        program1, program2, out_unknowns, num_res = partial_run_program(program, unks_in)
1408        ins1, ins2 = partition_list(unks_in, instruction.inputs)
1409        out_binders1, out_binders2 = partition_list(out_unknowns, instruction.out_binders)
1410        res = [Var(v.symval) for v in program2.in_binders[:num_res]]
1411        instruction1 = Instruction(self, ins1, dict(program=program1), out_binders1 + res)
1412        instruction2 = Instruction(self, res + ins2, dict(program=program2), out_binders2)
1413        return instruction1, instruction2, out_unknowns, res
class MainTrace(typing.NamedTuple):
1416class MainTrace(NamedTuple):
1417    level: int
1418    trace_type: Type["Trace"]
1419    global_data: Optional[Any]

class Trace:
1422class Trace:
1423    main: MainTrace
1424
1425    def __init__(self, main: MainTrace) -> None:
1426        self.main = main
1427
1428    def pure(self, val):
1429        raise NotImplementedError
1430
1431    def run_op(self, op, tracers, params):
1432        raise NotImplementedError
class RunTrace(Trace):
1435class RunTrace(Trace):
1436    pure = lambda self, x: x
1437
1438    def run_op(self, op: Operator, args, params):
1439        if isinstance(op, MetaOperator):
1440            args, params = op.reorg_args(args, params)
1441            args, params = op.args_fixer(*args, **params)
1442            ret = op.meta_impl(*args, **params)
1443        else:
1444            fn = self.get_fn(op, *tuple(SymbolicTensor.like(a) for a in args), **params)
1445            # with Timing(f"RUN {op}"):ret = jit(
1446            ret = jit(
1447                fn,
1448                static_argnames=("params",),
1449                name=jit.get_jit_name(args, params, op.name),
1450            )(*args, **params)
1451
1452        return ret
1453
1454    @staticmethod
1455    @lru_cache_verbose()
1456    def get_fn(op, *symval_args, **params):
1457        def fn(*args, **params):
1458            return [op(*args, **params)]
1459
1460        return fn
class SymbolicRunTrace(Trace):
1463class SymbolicRunTrace(Trace):
1464    # pure = lambda self, x: x
1465    def pure(self, val: Any) -> SymbolicTensor:
1466        return val.symval
1467
1468    def run_op(self, op, tracers, params):
1469        symvals_in = tree_map(lambda x: x.symval, tracers)
1470        symvals_out = op.typecheck(*symvals_in, **params)
1471        return symvals_out
class TraceTensor(Tensor):
1474class TraceTensor(Tensor):
1475    PYTHON_TYPES = {
1476        bool,
1477        int,
1478        float,
1479    }
1480    _trace: "Trace"
1481
1482    def __init__(self, *args, **kwargs):
1483        raise NotImplementedError
1484
1485    symval = property(lambda self: get_symval(self.val))
1486    dtype = property(lambda self: self.symval.dtype)
1487    shape = property(lambda self: self.symval.shape)
1488    device = property(lambda self: self.symval.device)
1489
1490    @property
1491    def val(self):
1492        raise NotImplementedError
1493
1494    def __str__(self):
1495        return repr(self)
1496
1497    def full_lower(self):
1498        return self
1499
1500    @property
1501    def ndim(self):
1502        return len(self.shape)
1503
1504    def __repr__(self):
1505        return f"{self.__class__.__name__}({repr(self.symval)})"
class VMapTraceTensor(TraceTensor):
1508class VMapTraceTensor(TraceTensor):
1509    def __init__(self, trace, val, vmap_dim):
1510        self._trace = trace
1511        self._val = val
1512        self.vmap_dim = vmap_dim
1513
1514    @property
1515    def val(self):
1516        return self._val
1517
1518    @property
1519    def symval(self):
1520        symval = get_symval(self.val)
1521        if self.vmap_dim is None:
1522            return symval
1523        else:
1524            shape = list(symval.shape)
1525            del shape[self.vmap_dim]
1526            return symval.like(shape=tuple(shape))
1527
1528    def full_lower(self):
1529        if self.vmap_dim is None:
1530            return full_lower(self.val)
1531        else:
1532            return self
class VMapTrace(Trace):
1535class VMapTrace(Trace):
1536    pure = lambda self, val: VMapTraceTensor(self, val, None)
1537
1538    @property
1539    def dim_size(self):
1540        return self.main.global_data
1541
1542    def run_op(self, op, tracers, params):
1543        vals_in, bdims_in = unzip2((t.val, t.vmap_dim) for t in tracers)
1544        val_outs, bdim_outs = op.vmap(self.dim_size, vals_in, bdims_in, **params)
1545        return [VMapTraceTensor(self, x, bd) for x, bd in list_zip(val_outs, bdim_outs)]
1546
1547    @staticmethod
1548    def move_vmap_dim(x, dim_size, src: int, dst: int):
1549        if src is None:  # unsqueeze and expand
1550            target_shape = list(x.shape)
1551            target_shape.insert(dst, dim_size)
1552            unsqueeze_shape = [1 if d == dst else target_shape[d] for d in range(len(target_shape))]
1553            x = x.reshape(tuple(unsqueeze_shape))
1554            x = x.expand(tuple(target_shape))
1555            return x
1556        elif src == dst:
1557            return x
1558        else:
1559            perm = [i for i in range(len(x.shape)) if i != src]
1560            perm.insert(dst, src)
1561            return x.permute(tuple(perm))
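VMapTrace.move_vmap_dim either materializes a missing batch dimension (src is None: reshape in a size-1 dim at dst, then expand) or relocates an existing one with a permute. A numpy sketch of the same index bookkeeping, with numpy standing in for the module's Tensor purely for illustration:

import numpy as np

dim_size = 4
x = np.ones((2, 3))                                   # unbatched operand
xb = np.broadcast_to(x.reshape(1, 2, 3), (4, 2, 3))   # src=None, dst=0: unsqueeze then expand
y = np.ones((2, 3, 4))                                 # batch dim currently at src=2
perm = [i for i in range(y.ndim) if i != 2]
perm.insert(0, 2)                                      # dst=0 -> perm == [2, 0, 1]
yb = y.transpose(perm)                                 # shape (4, 2, 3)
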
class JVPTraceTensor(TraceTensor):
1564class JVPTraceTensor(TraceTensor):
1565    def __init__(self, trace, primal, tangent):
1566        self._trace = trace
1567        self.primal = primal
1568        self.tangent = tangent
1569
1570    @property
1571    def symval(self):
1572        return get_symval(self.primal)
1573
1574    @property
1575    def val(self):
1576        return self.primal
1577
1578    @property
1579    def dtype(self):
1580        return self.primal.dtype
class JVPTrace(Trace):
1583class JVPTrace(Trace):
1584    def pure(self, val):
1585        if isinstance(val, PartialRunTrace):
1586            val = val.pval.const
1587        return JVPTraceTensor(self, val, backend.zeros_like(val))
1588
1589    def run_op(self, op, tracers, params):
1590        primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
1591        primals_out, tangents_out = op.jvp(primals_in, tangents_in, **params)
1592        return [JVPTraceTensor(self, x, t) for x, t in list_zip(primals_out, tangents_out)]
class ProgramTraceTensor(TraceTensor):
1595class ProgramTraceTensor(TraceTensor):
1596    __slots__ = ["symval"]
1597    symval: SymbolicTensor
1598
1599    def __init__(self, trace, symval):
1600        self._trace = trace
1601        self.symval = symval
class ProgramTrace(Trace):
1604class ProgramTrace(Trace):
1605    @property
1606    def builder(self):
1607        return self.main.global_data
1608
1609    def new_arg(self, symval) -> ProgramTraceTensor:
1610        symval = SymbolicTensor.like(symval)
1611        tracer = self.builder.new_tracer(self, symval)
1612        self.builder.tracer_to_var[id(tracer)] = Var(symval)
1613
1614        return tracer
1615
1616    def pure(self, val: Any) -> ProgramTraceTensor:
1617        # get_or_make_const_tracer
1618        tracer = self.builder.const_tracers.get(id(val))
1619        if tracer is None:
1620            tracer = self.builder.new_tracer(self, get_symval(val))
1621            self.builder.add_const(tracer, val)
1622        # print(self.builder.const_tracers)
1623        return tracer
1624
1625    def run_op(self, op, tracers, params):
1626        symvals_in = tree_map(lambda x: x.symval, tracers)
1627        symvals_out = op.typecheck(*symvals_in, **params)
1628
1629        out_tracers = [self.builder.new_tracer(self, a) for a in symvals_out]
1630        inputs = [self.builder.getvar(t) for t in tracers]
1631        outvars = [self.builder.add_var(t) for t in out_tracers]
1632
1633        self.builder.add_instruction(Instruction(op, inputs, params, outvars))
1634        return out_tracers
class ProgramBuilder:
1637class ProgramBuilder:
1638    instructions: List[Instruction]
1639    tracer_to_var: Dict[int, Var]
1640    const_tracers: Dict[int, TraceTensor]
1641    constvals: Dict[Var, Any]
1642    tracers: List[ProgramTraceTensor]
1643
1644    def __init__(self):
1645        self.instructions = []
1646        self.tracer_to_var = {}
1647        self.const_tracers = {}
1648        self.constvals = {}
1649        self.tracers = []
1650
1651    def new_tracer(self, trace: ProgramTrace, symval: SymbolicTensor) -> ProgramTraceTensor:
1652        tracer = ProgramTraceTensor(trace, symval)
1653        self.tracers += [tracer]
1654        return tracer
1655
1656    def add_instruction(self, instruction: Instruction) -> None:
1657        self.instructions += [instruction]
1658
1659    def add_var(self, tracer: ProgramTraceTensor) -> Var:
1660        assert id(tracer) not in self.tracer_to_var
1661        var = self.tracer_to_var[id(tracer)] = Var(tracer.symval)
1662        return var
1663
1664    def getvar(self, tracer: ProgramTraceTensor) -> Var:
1665        var = self.tracer_to_var.get(id(tracer))
1666        assert var is not None
1667        return var
1668
1669    def add_const(self, tracer: ProgramTraceTensor, val: Any) -> Var:
1670        var = self.add_var(tracer)
1671        self.const_tracers[id(val)] = tracer
1672        self.constvals[var] = val
1673        return var
1674
1675    def build(self, in_tracers: Any, out_tracers: Any, static_args, name) -> Tuple[Program, List[Any]]:
1676        constvars, constvals = unzip2(self.constvals.items())
1677        t2v = lambda t: self.tracer_to_var[id(t)]
1678        in_binders = constvars + [t2v(t) for t in in_tracers]
1679        out_vars = [t2v(t) for t in out_tracers]
1680        program = Program(
1681            in_binders,
1682            self.instructions,
1683            out_vars,
1684            len(constvals),
1685            static_args,
1686            name,
1687        )
1688        typecheck_program(program)
1689        program, constvals = self._inline_literals(program, constvals)
1690        typecheck_program(program)
1691        # dblog(program, enable=backend.LOG_PROGRAM)
1692        return program, constvals
1693
1694    def _inline_literals(self, program: Program, consts: List[Any]) -> Tuple[Program, List[Any]]:
1695        const_binders, other_binders = split_list(program.in_binders, len(consts))
1696        scalars = [type(x) in TraceTensor.PYTHON_TYPES and not get_symval(x).shape for x in consts]
1697        new_const_binders, lit_binders = partition_list(scalars, const_binders)
1698        new_consts, lit_vals = partition_list(scalars, consts)
1699        literals = dict(list_zip(lit_binders, list_map(Lit, lit_vals)))
1700        new_outs = [literals.get(x, x) for x in program.outs]
1701        new_instructions = [
1702            Instruction(
1703                instruction.op,
1704                [literals.get(x, x) for x in instruction.inputs],
1705                instruction.params,
1706                instruction.out_binders,
1707            )
1708            for instruction in program.instructions
1709        ]
1710        new_program = Program(
1711            new_const_binders + other_binders,
1712            new_instructions,
1713            new_outs,
1714            len(new_consts),
1715            program.static_args,
1716            program.name,
1717        )
1718        return new_program, tuple(new_consts)
1719
1720    def get_current_scope_info(self):
1721        current_frame = inspect.currentframe()
1722        current_function_name = current_frame.f_code.co_name
1723        current_module_name = inspect.getmodulename(current_frame.f_code.co_filename)
1724        current_class_name = None
1725        for frame_info in inspect.getouterframes(current_frame):
1726            print(frame_info)
1727            frame_locals = frame_info.frame.f_locals
1728            print(frame_locals)
1729            if "self" in frame_locals:
1730                current_class_name = frame_locals["self"].__class__.__name__
1731                break
1732        return {
1733            "Function": current_function_name,
1734            "Module": current_module_name,
1735            "Class": current_class_name,
1736        }
class UndefinedPrimal(typing.NamedTuple):
1739class UndefinedPrimal(NamedTuple):
1740    symval: SymbolicTensor
1741
1742    @property
1743    def shape(self):
1744        return self.symval.shape
1745
1746    @property
1747    def dtype(self):
1748        return self.symval.dtype
1749
1750    @property
1751    def device(self):
1752        return self.symval.device
1753
1754    @property
1755    def ndim(self):
1756        return self.symval.ndim
1757
1758    def __repr__(self):
1759        return f"<UndefinedPrimal: symval={self.symval}>"
1760
1761    str_short = __repr__

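UndefinedPrimal marks the input slots a transpose rule is expected to produce cotangents for. JitOp.T above partitions its invals on this flag, runs the transposed program on the concrete residuals plus incoming cotangents, and scatters the results back into the undefined slots. A plain-Python sketch of that scatter step, with string placeholders standing in for real tensors and symvals:

invals = ["residual0", UndefinedPrimal("sym_a"), UndefinedPrimal("sym_b")]  # placeholders, not real symvals
undef_primals = [isinstance(x, UndefinedPrimal) for x in invals]            # [False, True, True]
outs = iter(["ct_a", "ct_b"])                                               # cotangents from the transposed program
result = [next(outs) if undef else None for undef in undef_primals]         # [None, 'ct_a', 'ct_b']
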
class PartialValue(typing.NamedTuple):
1764class PartialValue(NamedTuple):
1765    symval: SymbolicTensor
1766    const: Optional[Any]
1767
1768    is_known = property(lambda self: self.const is not None)
1769    is_unknown = property(lambda self: self.const is None)

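A PartialValue is known when it still carries a concrete constant and unknown when only the symbolic type remains; PartialRunTrace below uses this split to run what it can immediately and stage the rest. A minimal sketch, assuming x is a concrete Tensor and using get_symval and SymbolicTensor.like from elsewhere in slope.core (make_known_pval and make_unknown_pval, used below, presumably wrap constructions like these):

known = PartialValue(get_symval(x), x)                               # known.is_known -> True
unknown = PartialValue(SymbolicTensor.like(get_symval(x)), None)     # unknown.is_unknown -> True
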
class LambdaBindingDraft(typing.NamedTuple):
1772class LambdaBindingDraft(NamedTuple):
1773    pass

class ConstDraft(typing.NamedTuple):
1776class ConstDraft(NamedTuple):
1777    val: Any

class InstructionDraft(typing.NamedTuple):
1780class InstructionDraft(NamedTuple):
1781    prim: Operator
1782    tracers_in: List["PartialRunTraceTensor"]
1783    params: Dict[str, Any]
1784    symvals_out: List[SymbolicTensor]
1785    tracer_refs_out: List[weakref.ReferenceType["PartialRunTraceTensor"]]

ProgramDraft = typing.Union[LambdaBindingDraft, ConstDraft, InstructionDraft]
class PartialRunTraceTensor(TraceTensor):
1791class PartialRunTraceTensor(TraceTensor):
1792    def __init__(self, trace, pval, draft):
1793        self._trace = trace
1794        self.pval = pval
1795        self.draft = draft
1796
1797    symval = property(lambda self: self.pval.symval)
1798    val = property(lambda self: self.pval.const)
1799
1800    def full_lower(self):
1801        if self.pval.is_known:
1802            return full_lower(self.pval.const)
1803        return self
class PartialRunTrace(Trace):
1806class PartialRunTrace(Trace):
1807    def new_arg(self, pval: PartialValue) -> Any:
1808        return PartialRunTraceTensor(self, pval, LambdaBindingDraft())
1809
1810    def pure(self, val: Any) -> PartialRunTraceTensor:
1811        return PartialRunTraceTensor(self, make_known_pval(val), None)
1812
1813    def instantiate_const(self, tracer: PartialRunTraceTensor) -> PartialRunTraceTensor:
1814        if tracer.pval.is_unknown:
1815            return tracer
1816        else:
1817            pval = make_unknown_pval(SymbolicTensor.like(tracer.symval))
1818            return PartialRunTraceTensor(self, pval, ConstDraft(tracer.pval.const))
1819
1820    def run_op(self, op, tracers, params):
1821        is_knowns = tuple(t.pval.is_known for t in tracers)
1822
1823        if all(is_knowns):
1824            return bind(op, *list_map(full_lower, tracers), **params)
1825        return op.partial_run(self, tracers, **params)
trace_stack: List[MainTrace] = [MainTrace(level=0, trace_type=<class 'RunTrace'>, global_data=None)]
stashed_trace: Optional[MainTrace] = None
class UndefBackend:
1833class UndefBackend:
1834    def __getattr__(self, attr):
1835        raise NotImplementedError("Backend not init yet with slope.core.set_backend(backend)")
backend = <slope.backends.onnxruntime.ONNXRuntimeBackend object>
def set_backend(name, where='slope.backends'):
1841def set_backend(name, where="slope.backends"):
1842    global backend
1843    backend = importlib.import_module(f"{where}.{name}").backend
1844    import slope.nn as nn
1845
1846    # backend.register_node(nn.Module, nn.Module.flatten, nn.Module.unflatten, "Module")
1847
1848    dblog(f"slope backend is {backend}", enable=backend.LOG_INIT)
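
A minimal usage sketch (assuming the bundled `onnxruntime` backend module exists under `slope.backends`, as the default `backend` value shown above suggests):

    import slope.core

    # Import slope.backends.onnxruntime and install its `backend` object as the global backend.
    slope.core.set_backend("onnxruntime")
    print(slope.core.backend)  # <slope.backends.onnxruntime.ONNXRuntimeBackend object ...>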
def stack_str():
1851def stack_str():
1852    ret = ""
1853    for trace in trace_stack:
1854        ret += f"{trace.level}: {trace.trace_type.__name__}\t{trace.global_data=}\n"
1855    return ret
def make_known_pval(val: Any):
1858def make_known_pval(val: Any):
1859    return PartialValue(get_symval(val), val)
def make_unknown_pval(symval: SymbolicTensor):
1862def make_unknown_pval(symval: SymbolicTensor):
1863    return PartialValue(symval, None)
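
These two helpers are how partial evaluation tags values: a known `PartialValue` carries a concrete constant, an unknown one carries only a symbolic shape/dtype. A small sketch (assuming a backend is set so `backend.ones` is available):

    x = backend.ones((3,))
    known = make_known_pval(x)                                        # const is x, is_known == True
    unknown = make_unknown_pval(SymbolicTensor.like(get_symval(x)))   # const is None, is_unknown == True
    assert known.is_known and unknown.is_unknown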
def get_symval(x):
1866def get_symval(x):
1867    if isinstance(x, TraceTensor):
1868        return x.symval
1869    elif type(x) in TraceTensor.PYTHON_TYPES:
1870        return backend.tensor(x)
1871    elif isinstance(x, Tensor):
1872        return x
1873    elif isinstance(x, SymbolicTensor):
1874        return x
1875    else:
1876        raise TypeError(type(x))
def tree_flatten(x: Any) -> Any:
1879def tree_flatten(x: Any) -> Any:
1880    def _tree_flatten(x_: Any) -> Tuple[Iterable, Union[TreeDef, Leaf]]:
1881        node_type = None
1882        for k in backend.node_types.keys():
1883            if isinstance(x_, k):
1884                node_type = backend.node_types[k]
1885
1886        if node_type is not None:
1887            node_metadata, children = node_type.flatten(x_)
1888            children_flat, child_trees = unzip2(list_map(_tree_flatten, children))
1889            children_iter = itertools.chain.from_iterable(children_flat)
1890            treedef = TreeDef(node_type, node_metadata, tuple(child_trees))
1891            return children_iter, treedef
1892        else:
1893            return (x_,), Leaf(x_)
1894
1895    children_iter, treedef = _tree_flatten(x)
1896    return tuple(children_iter), treedef
def tree_unflatten(treedef: TreeDef, xs: Tuple[Any]) -> Any:
1899def tree_unflatten(treedef: TreeDef, xs: Tuple[Any]) -> Any:
1900    def _tree_unflatten(treedef_: TreeDef, xs_: Iterator) -> Any:
1901        if isinstance(treedef_, Leaf):
1902            dblog(f"    tree leaf found: {xs_}\n", enable=backend.LOG_TREE)
1903            return next(xs_)
1904        else:
1905            dblog(f"    now\n  {treedef_}", enable=backend.LOG_TREE)
1906            children = (_tree_unflatten(t, xs_) for t in treedef_.child_treedefs)
1907            dblog(f"{children=}\n", enable=backend.LOG_TREE)
1908            return treedef_.node_type.unflatten(treedef_.node_metadata, children)
1909
1910    dblog(f"unflattening {treedef}", enable=backend.LOG_TREE)
1911    return _tree_unflatten(treedef, iter(xs))
1912    # with Timing(f"\nTREE:\n{treedef}"):
1913    #     ret = _tree_unflatten(treedef, iter(xs))
1914    # return ret
def tree_transpose(outer_treedef: TreeDef, inner_treedef: TreeDef, tree_to_transpose: Any) -> Any:
1917def tree_transpose(
1918    outer_treedef: TreeDef,
1919    inner_treedef: TreeDef,
1920    tree_to_transpose: Any,
1921) -> Any:
1922    flat, treedef = tree_flatten(tree_to_transpose)
1923    inner_size = inner_treedef.num_leaves
1924    outer_size = outer_treedef.num_leaves
1925    if treedef.num_leaves != (inner_size * outer_size):
1926        raise TypeError
1927    iter_flat = iter(flat)
1928    lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)]
1929    permuted_lol = zip(*lol)
1930    subtrees = map(partial(tree_unflatten, outer_treedef), permuted_lol)
1931    return tree_unflatten(inner_treedef, subtrees)
def flatten_fn(f, in_tree, *, has_aux=False):
1934def flatten_fn(f, in_tree, *, has_aux=False):
1935    store = Store()
1936
1937    def flat_fn(*args_flat, **params):
1938        tree_args = tree_unflatten(in_tree, args_flat)
1939        out = f(*tree_args, **params)
1940        if has_aux:
1941            out, aux = out
1942        out_flat, out_tree = tree_flatten(out)
1943        store.set_value(out_tree)
1944        return (out_flat, aux) if has_aux else out_flat
1945
1946    return flat_fn, store
def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any:
1949def tree_map(f: Callable[..., Any], tree, *rest, out_leaf=False) -> Any:
1950    leaves, treedef = tree_flatten(tree)
1951    if len(rest) == 0:
1952        out_tree_flat = tuple(f(leaf) for leaf in leaves)
1953        out_tree = tree_unflatten(treedef, out_tree_flat)
1954    else:
1955        all_leaves = [leaves]
1956        for t in rest:
1957            t_leaves, t_treedef = tree_flatten(t)
1958            assert t_treedef == treedef
1959            all_leaves += [t_leaves]
1960
1961        out_tree_flat = tuple(f(*xs) for xs in zip(*all_leaves))
1962        out_tree = tree_unflatten(treedef, out_tree_flat)
1963    ret = out_tree
1964    if out_leaf:
1965        ret = (ret, tree_flatten(out_tree_flat[0]))
1966    return ret
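
A usage sketch (assuming a backend is already set and that it registers the built-in containers such as `dict` and `list` as tree node types; unregistered containers are treated as single leaves):

    t1 = {"a": [1.0, 2.0], "b": 3.0}
    t2 = {"a": [10.0, 20.0], "b": 30.0}
    doubled = tree_map(lambda x: 2 * x, t1)          # {"a": [2.0, 4.0], "b": 6.0}
    summed = tree_map(lambda x, y: x + y, t1, t2)    # {"a": [11.0, 22.0], "b": 33.0}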
@contextmanager
def new_main_trace(trace_type: Type[Trace], global_data=None):
1969@contextmanager
1970def new_main_trace(trace_type: Type["Trace"], global_data=None):
1971    global trace_stack
1972    level = len(trace_stack)
1973    main = MainTrace(level, trace_type, global_data)
1974    trace_stack += [main]
1975
1976    try:
1977        yield main
1978    finally:
1979        trace_stack.pop()
def bind(op, *args, **params):
1982def bind(op, *args, **params):
1983    top_trace = find_top_trace(args)
1984    tracers = tuple([full_raise(top_trace, arg) for arg in args])
1985    outs = top_trace.run_op(op, tracers, params)
1986    lowered = tuple([full_lower(out) for out in outs])
1987    return lowered
def find_top_trace(xs) -> Trace:
1990def find_top_trace(xs) -> Trace:
1991    arrs = []
1992
1993    def get_arr_from_seq(seq):
1994        nonlocal arrs
1995        for x in seq:
1996            if type(x) in (tuple, list):
1997                get_arr_from_seq(x)
1998            elif isinstance(x, TraceTensor):
1999                arrs += [x]
2000
2001    get_arr_from_seq(xs)
2002    arrs = tuple(arrs)
2003    top_main = max(
2004        (x._trace.main for x in arrs),
2005        default=trace_stack[0],
2006        key=operator_py.attrgetter("level"),
2007    )
2008    if stashed_trace and stashed_trace.level > top_main.level:
2009        top_main = stashed_trace
2010    return top_main.trace_type(top_main)
def full_raise(trace: Trace, val: Any) -> TraceTensor:
2013def full_raise(trace: Trace, val: Any) -> TraceTensor:
2014    if not isinstance(val, TraceTensor):
2015        return trace.pure(val)
2016    level = trace.main.level
2017    if val._trace.main is trace.main:
2018        return val
2019    elif val._trace.main.level < level:
2020        return trace.pure(val)
2021    elif val._trace.main.level > level:
2022        raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
2023    else:
2024        raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
def full_lower(val: Any):
2027def full_lower(val: Any):
2028    if isinstance(val, TraceTensor):
2029        return val.full_lower()
2030    elif type(val) in (list, tuple):
2031        return tuple(full_lower(v) for v in val)
2032    else:
2033        return val
def typecheck_program(program: Program) -> ProgramType:
2036def typecheck_program(program: Program) -> ProgramType:
2037    env: Set[Var] = set()
2038
2039    for v in program.in_binders:
2040        if v in env:
2041            raise TypeError
2042        env.add(v)
2043
2044    for instruction in program.instructions:
2045        in_types = [typecheck_atom(env, x) for x in instruction.inputs]
2046        out_types = instruction.op.typecheck(*in_types, **instruction.params)
2047        for out_binder, out_type in list_zip(instruction.out_binders, out_types):
2048            if not out_type == out_binder.symval:
2049                raise TypeError
2050        for out_binder in instruction.out_binders:
2051            if out_binder in env:
2052                raise TypeError
2053            env.add(out_binder)
2054
2055    in_types = [v.symval for v in program.in_binders]
2056    out_types = [typecheck_atom(env, x) for x in program.outs]
2057    return ProgramType(tuple(in_types), tuple(out_types))
def typecheck_atom(env: Set[Var], x: Union[Var, Lit]) -> SymbolicTensor:
2060def typecheck_atom(env: Set[Var], x: Atom) -> SymbolicTensor:
2061    if isinstance(x, Var):
2062        if x not in env:
2063            raise TypeError("unbound variable")
2064        return x.symval
2065    elif isinstance(x, Lit):
2066        return get_symval(x.val)
2067    else:
2068        assert False
def run_program(program: Program, args: List[Any]) -> List[Any]:
2071def run_program(program: Program, args: List[Any]) -> List[Any]:
2072    env: Dict[Var, Any] = {}
2073
2074    def read(x: Atom) -> Any:
2075        return env[x] if type(x) is Var else x.val
2076
2077    def write(v: Var, val: Any) -> None:
2078        assert v not in env  # single-assignment
2079        env[v] = val
2080
2081    list_map(write, program.in_binders, args)
2082    for instruction in program.instructions:
2083        in_vals = list_map(read, instruction.inputs)
2084        outs = bind(instruction.op, *in_vals, **instruction.params)
2085        list_map(write, instruction.out_binders, outs)
2086    return list_map(read, program.outs)
def program_as_fun(program: Program):
2089def program_as_fun(program: Program):
2090    return lambda *args: run_program(program, args)
def vmap_flat(f, in_dim, out_dim, dim_size, *args):
2093def vmap_flat(f, in_dim, out_dim, dim_size, *args):
2094    if dim_size is None:
2095        dims = set([x.shape[d] for x, d in list_zip(args, in_dim) if d is not None])
2096        assert len(dims) == 1
2097        (dim_size,) = dims
2098    with new_main_trace(VMapTrace, dim_size) as main:
2099        trace = VMapTrace(main)
2100        tracers_in = [VMapTraceTensor(trace, x, dim) if dim is not None else x for x, dim in list_zip(args, in_dim)]
2101        outs = f(*tracers_in)
2102        tracers_out = [full_raise(trace, out) for out in outs]
2103        vals_out, y_vmap_dims = unzip2((t.val, t.vmap_dim) for t in tracers_out)
2104    ret = [VMapTrace.move_vmap_dim(val_out, dim_size, bdim, out_dim) for val_out, bdim, out_dim in zip(vals_out, y_vmap_dims, out_dim)]
2105    return ret
def vmap(f, in_dim=0, out_dim=0, dim_size=None):
2108def vmap(f, in_dim=0, out_dim=0, dim_size=None):
2109    def batched_f(*args):
2110        nonlocal in_dim, out_dim, dim_size
2111        args_flat, in_tree = tree_flatten(args)
2112        in_dim = (in_dim,) * len(args) if isinstance(in_dim, int) else in_dim
2113        out_dim = (out_dim,) * len(args) if isinstance(out_dim, int) else out_dim
2114        in_dim_flat, in_dim_tree = tree_flatten(in_dim)
2115        out_dim_flat, out_dim_tree = tree_flatten(out_dim)
2116        if not (in_tree == in_dim_tree == out_dim_tree):
2117            raise TypeError(f"\n{in_tree}\n!=\n{in_dim_tree}!=\n{out_dim_tree}")
2118        f_flat, out_tree_store = flatten_fn(f, in_tree)
2119        # if len(args_flat) > len(in_dim_flat):
2120        #     in_dim_flat = (in_dim[0],) * len(args_flat)
2121        outs_flat = vmap_flat(f_flat, in_dim_flat, out_dim_flat, dim_size, *args_flat)
2122        return tree_unflatten(out_tree_store(), outs_flat)
2123
2124    return batched_f
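
A usage sketch of `vmap` (assuming a backend is set and that slope tensors support Python arithmetic operators, as the cotangent accumulation in `backward_pass` below relies on):

    def scale(x):
        return x * 2.0

    xs = backend.ones((4, 3))                    # the leading dimension is the mapped (batch) dimension
    ys = vmap(scale, in_dim=0, out_dim=0)(xs)    # scale is applied per row; ys has shape (4, 3)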
def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args):
2127def jvp_flat(f, primals, tangents, *, has_aux, global_data, **static_args):
2128    with new_main_trace(JVPTrace, global_data) as main:
2129        trace = JVPTrace(main)
2130        tracers_in = [JVPTraceTensor(trace, x, t) for x, t in list_zip(primals, tangents)]
2131        jvp_flat_ret = f(*tracers_in, **static_args)
2132        if has_aux:
2133            (outs, aux) = jvp_flat_ret
2134            # aux_ = aux
2135            aux = tree_map(lambda x: x.primal, aux)
2136            # aux = tree_map(lambda x: x.full_lower(), aux)
2137            #
2138        else:
2139            outs = jvp_flat_ret
2140        tracers_out = [full_raise(trace, out) for out in outs]
2141        primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
2142    return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
def jvp(f, primals, tangents, *, has_aux=False, global_data=None, **static_args):
2145def jvp(f, primals, tangents, *, has_aux=False, global_data=None, **static_args):
2146    primals_flat, in_tree = tree_flatten(primals)
2147    tangents_flat, in_tree2 = tree_flatten(tangents)
2148    for p, t in zip(primals_flat, tangents_flat):
2149        assert p.shape == t.shape, f"{p.shape=} != {t.shape=}"
2150        assert p.dtype == t.dtype, f"{p.dtype=} != {t.dtype=}"
2151        assert p.device == t.device, f"{p.device=} != {t.device=}"
2152    if in_tree != in_tree2:
2153        raise TypeError
2154    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2155    jvp_ret = jvp_flat(
2156        f,
2157        primals_flat,
2158        tangents_flat,
2159        has_aux=has_aux,
2160        global_data=global_data,
2161        **static_args,
2162    )
2163    if has_aux:
2164        (primals_out_flat, tangents_out_flat), aux = jvp_ret
2165    else:
2166        (primals_out_flat, tangents_out_flat) = jvp_ret
2167    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2168    tangents_out = tree_unflatten(out_tree_store(), tangents_out_flat)
2169    return ((primals_out, tangents_out), aux) if has_aux else (primals_out, tangents_out)
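
A usage sketch of `jvp` (forward mode): primals and tangents are passed as matching trees, and the result pairs the primal outputs with the directional derivative along the tangents. The tensor constructor and operators here are assumptions consistent with the rest of this module:

    def f(x):
        return x * x

    x = backend.ones((3,))
    t = backend.ones((3,))
    y, y_dot = jvp(f, (x,), (t,))   # y == f(x); y_dot is the derivative of f at x applied to t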
def jacfwd(f, argnums=0, has_aux=False):
2172def jacfwd(f, argnums=0, has_aux=False):
2173    def jvp_fn(x):
2174        return jvp(f, x, (backend.eye(len(x)),), has_aux=has_aux)
2175    return vmap(jvp_fn, in_dim=argnums)
def jacrev(f, argnums=0, has_aux=False):
2177def jacrev(f, argnums=0, has_aux=False):
2178    def grad_f(x):
2179        return grad(lambda x: f(x) @ backend.eye(f(x).shape[0]), argnums=argnums, has_aux=has_aux)(x)
2180    return vmap(grad_f)
def hessian(fn, argnums=0, has_aux=False):
2188def hessian(fn, argnums=0, has_aux=False):
2189    return jacrev(jacrev(fn, argnums=argnums, has_aux=has_aux))
@contextmanager
def stash_trace(main: MainTrace):
2191@contextmanager
2192def stash_trace(main: MainTrace):
2193    global stashed_trace
2194    prev_stashed_trace, stashed_trace = stashed_trace, main
2195    try:
2196        yield
2197    finally:
2198        stashed_trace = prev_stashed_trace
@contextmanager
def symbolic_run():
2201@contextmanager
2202def symbolic_run():
2203    global trace_stack
2204    level = len(trace_stack)
2205    main = MainTrace(level, SymbolicRunTrace, global_data=None)
2206    trace_stack += [main]
2207    global stashed_trace
2208    prev_stashed_trace, stashed_trace = stashed_trace, main
2209    try:
2210        yield
2211    finally:
2212        stashed_trace = prev_stashed_trace
2213        trace_stack.pop()
def make_program(*args, **kwargs) -> Any:
(Source listing omitted: `make_program` is wrapped in a caching decorator, so the documentation renderer picked up the decorator's generic `decorated_function` wrapper, the `wrapper.cache_info()` / LOG_LRU logging shim, instead of the body of `make_program` itself.)
def vmap_program(*args, **kwargs) -> Any:
(Source listing omitted: as with `make_program`, the renderer captured the caching decorator's `decorated_function` wrapper rather than the body of `vmap_program`.)
def jvp_program(*args, **kwargs) -> Any:
(Source listing omitted: as with `make_program`, the renderer captured the caching decorator's `decorated_function` wrapper rather than the body of `jvp_program`.)
def partial_run_flat(f: Callable, pvals_in: List[PartialValue], has_aux, global_data=None) -> Tuple[Program, List[PartialValue], List[Any]]:
2271def partial_run_flat(
2272    f: Callable, pvals_in: List["PartialValue"], has_aux, global_data=None
2273) -> Tuple[Program, List["PartialValue"], List[Any]]:
2274    with new_main_trace(PartialRunTrace, global_data) as main:
2275        trace = PartialRunTrace(main)
2276        tracers_in = [trace.new_arg(pval) for pval in pvals_in]
2277        outs = f(*tracers_in)
2278        if has_aux:
2279            outs, aux = outs
2280        tracers_out = [full_raise(trace, out) for out in outs]
2281        pvals_out = [t.pval for t in tracers_out]
2282        unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
2283        unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
2284        program, consts = tracers_to_program(unk_tracers_in, unk_tracers_out)
2285
2286    return (program, pvals_out, consts, aux) if has_aux else (program, pvals_out, consts)
def partial_run_program(program: Program, in_unknowns: List[bool], instantiate: Optional[List[bool]] = None) -> Tuple[Program, Program, List[bool], int]:
2289def partial_run_program(
2290    program: Program,
2291    in_unknowns: List[bool],
2292    instantiate: Optional[List[bool]] = None,
2293) -> Tuple[Program, Program, List[bool], int]:
2294    env: Dict[Var, bool] = {}
2295    residuals: Set[Var] = set()
2296
2297    def read(x: Atom) -> bool:
2298        return type(x) is Var and env[x]
2299
2300    def write(unk: bool, v: Var) -> None:
2301        env[v] = unk
2302
2303    instructions1, instructions2 = [], []
2304    list_map(write, in_unknowns, program.in_binders)
2305
2306    for instruction in program.instructions:
2307        unks_in = list_map(read, instruction.inputs)
2308        (
2309            instruction1,
2310            instruction2,
2311            unks_out,
2312            res,
2313        ) = instruction.op.partial_run_instruction(unks_in, instruction)
2314        if instruction1 is not None:
2315            instructions1 += [instruction1]
2316        if instruction2 is not None:
2317            instructions2 += [instruction2]
2318        if res is not None:
2319            residuals.update(res)
2320        list_map(write, unks_out, instruction.out_binders)
2321
2322    out_unknowns = list_map(read, program.outs)
2323    if instantiate is not None:
2324        for v, uk, inst in zip(program.outs, out_unknowns, instantiate):
2325            if inst and not uk:
2326                if type(v) is Var:
2327                    residuals.add(v)
2328        out_unknowns = list_map(operator_py.or_, out_unknowns, instantiate)
2329
2330    residuals, num_res = list(residuals), len(residuals)
2331    assert all(type(v) is Var for v in residuals), residuals
2332
2333    ins1, ins2 = partition_list(in_unknowns, program.in_binders)
2334    outs1, outs2 = partition_list(out_unknowns, program.outs)
2335
2336    program1 = Program(
2337        ins1,
2338        instructions1,
2339        outs1 + residuals,
2340        0,
2341        program.static_args,
2342        f"{program.name}_partial1",
2343    )
2344    program2 = Program(
2345        residuals + ins2,
2346        instructions2,
2347        outs2,
2348        0,
2349        program.static_args,
2350        f"{program.name}_partial2",
2351    )
2352    typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2)
2353
2354    return program1, program2, out_unknowns, num_res
def typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2):
2357def typecheck_partial_run_program(program, in_unknowns, out_unknowns, program1, program2):
2358    programty = typecheck_program(program)  # (a1,  a2) -> (b1, b2 )
2359    program1ty = typecheck_program(program1)  #  a1       -> (b1, res)
2360    program2ty = typecheck_program(program2)  # (res, a2) -> b2
2361
2362    a1, a2 = partition_list(in_unknowns, programty.in_types)
2363    b1, b2 = partition_list(out_unknowns, programty.out_types)
2364    b1_, res = split_list(program1ty.out_types, len(b1))
2365    res_, a2_ = split_list(program2ty.in_types, len(res))
2366    b2_ = program2ty.out_types
2367
2368    a1 = tuple(a1)
2369    a2, a2_ = tuple(a2), tuple(a2_)
2370    b1, b1_ = tuple(b1), tuple(b1_)
2371    b2, b2_ = tuple(b2), tuple(b2_)
2372    res, res_ = tuple(res), tuple(res_)
2373
2374    if program1ty.in_types != a1:
2375        raise TypeError
2376    if program2ty.out_types != b2:
2377        raise TypeError
2378    if b1 != b1_:
2379        raise TypeError
2380    if res != res_:
2381        raise TypeError
2382    if a2 != a2_:
2383        raise TypeError
2384    if b2 != b2_:
2385        raise TypeError
def linearize_flat(f, *primals_in, has_aux):
2388def linearize_flat(f, *primals_in, has_aux):
2389    pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in]
2390
2391    def f_jvp(*primals_tangents_in):
2392        jvp_ret = jvp(f, *split_half(primals_tangents_in), has_aux=has_aux)
2393        if has_aux:
2394            (primals_out, tangents_out), aux = jvp_ret
2395            return ((*primals_out, *tangents_out), aux)
2396        else:
2397            primals_out, tangents_out = jvp_ret
2398            return (*primals_out, *tangents_out)
2399
2400    partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux)
2401    if has_aux:
2402        program, pvals_out, consts, aux = partial_run_flat_ret
2403    else:
2404        program, pvals_out, consts = partial_run_flat_ret
2405    primal_pvals, _ = split_half(pvals_out)
2406    assert all(pval.is_known for pval in primal_pvals)
2407    primals_out = [pval.const for pval in primal_pvals]
2408    f_lin = lambda *tangents: run_program(program, [*consts, *tangents])
2409    return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
def linearize(f, *primals_in, has_aux=False):
2412def linearize(f, *primals_in, has_aux=False):
2413    primals_in_flat, in_tree = tree_flatten(primals_in)
2414    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2415    linearize_flat_ret = linearize_flat(f, *primals_in_flat, has_aux=has_aux)
2416    if has_aux:
2417        primals_out_flat, f_lin_flat, aux = linearize_flat_ret
2418    else:
2419        primals_out_flat, f_lin_flat = linearize_flat_ret
2420
2421    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2422
2423    def f_lin(*tangents_in):
2424        tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
2425        if in_tree != in_tree2:
2426            raise TypeError
2427        tangents_out_flat = f_lin_flat(*tangents_in_flat)
2428        return tree_unflatten(out_tree_store(), tangents_out_flat)
2429
2430    return (primals_out, f_lin, aux) if has_aux else (primals_out, f_lin)
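
A usage sketch of `linearize`: it evaluates the primal once and returns a function that replays only the tangent (linear) part of the program. Tensor constructors are assumptions, as above:

    def f(x):
        return x * x

    x = backend.ones((3,))
    y, f_lin = linearize(f, x)
    y_dot = f_lin(backend.ones((3,)))   # apply the linearized map to a tangent vector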
def tracers_to_program(tracers_in: List[PartialRunTraceTensor], tracers_out: List[PartialRunTraceTensor]):
2433def tracers_to_program(
2434    tracers_in: List["PartialRunTraceTensor"],
2435    tracers_out: List["PartialRunTraceTensor"],
2436):
2437    def tracer_parents(t: PartialRunTraceTensor) -> List[PartialRunTraceTensor]:
2438        return t.draft.tracers_in if isinstance(t.draft, InstructionDraft) else []
2439
2440    def draft_to_instruction(tracer_to_var: Dict[int, Var], draft: InstructionDraft) -> Instruction:
2441        inputs = [tracer_to_var[id(t)] for t in draft.tracers_in]
2442        out_binders = [Var(symval) for symval in draft.symvals_out]
2443        for t_ref, var in list_zip(draft.tracer_refs_out, out_binders):
2444            if t_ref() is not None:
2445                tracer_to_var[id(t_ref())] = var
2446        return Instruction(draft.prim, inputs, draft.params, out_binders)
2447
2448    tracer_to_var: Dict[int, Var] = {id(t): Var(SymbolicTensor.like(t.symval)) for t in tracers_in}
2449    constvar_to_val: Dict[int, Any] = {}
2450    constid_to_var: Dict[int, Var] = {}
2451    processed_instructions: Set[int] = set()
2452    instructions: List[Instruction] = []
2453    for t in toposort(tracers_out, tracer_parents):
2454        if isinstance(t.draft, LambdaBindingDraft):
2455            assert id(t) in set(list_map(id, tracers_in))
2456        elif isinstance(t.draft, ConstDraft):
2457            val = t.draft.val
2458            var = constid_to_var.get(id(val))
2459            if var is None:
2460                symval = SymbolicTensor.like(get_symval(val))
2461                var = constid_to_var[id(val)] = Var(symval)
2462                constvar_to_val[var] = val
2463            tracer_to_var[id(t)] = var
2464        elif isinstance(t.draft, InstructionDraft):
2465            if id(t.draft) not in processed_instructions:
2466                instructions += [draft_to_instruction(tracer_to_var, t.draft)]
2467                processed_instructions.add(id(t.draft))
2468        else:
2469            raise TypeError(t.draft)
2470
2471    constvars, constvals = unzip2(constvar_to_val.items())
2472    in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
2473    # out_vars = [tracer_to_var[id(t)] for t in tracers_out if id(t) in tracer_to_var]
2474    out_vars = [tracer_to_var[id(t)] for t in tracers_out]
2475    program = Program(tuple(in_binders), tuple(instructions), tuple(out_vars))
2476    typecheck_program(program)
2477    return program, constvals
def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
2480def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):
2481    def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):
2482        seen = set()
2483        for node in nodes:
2484            assert all(id(parent) in seen for parent in parents(node))
2485            seen.add(id(node))
2486
2487    def remove_duplicates(lst):
2488        seen = set()
2489        return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
2490
2491    if not out_nodes:
2492        return []
2493    out_nodes = remove_duplicates(out_nodes)
2494
2495    child_counts = {}
2496    stack = list(out_nodes)
2497    while stack:
2498        node = stack.pop()
2499        if id(node) in child_counts:
2500            child_counts[id(node)] += 1
2501        else:
2502            child_counts[id(node)] = 1
2503            stack.extend(parents(node))
2504    for node in out_nodes:
2505        child_counts[id(node)] -= 1
2506
2507    sorted_nodes = []
2508    childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
2509    while childless_nodes:
2510        node = childless_nodes.pop()
2511        sorted_nodes += [node]
2512        for parent in parents(node):
2513            if child_counts[id(parent)] == 1:
2514                childless_nodes += [parent]
2515            else:
2516                child_counts[id(parent)] -= 1
2517
2518    sorted_nodes = sorted_nodes[::-1]
2519    check_toposort(sorted_nodes, parents)
2520    return sorted_nodes
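
`toposort` is a generic helper: given output nodes and a `parents` callback, it returns the nodes ordered so that every parent precedes its children. A self-contained sketch (the `Node` class here is purely illustrative):

    class Node:
        def __init__(self, name, parents=()):
            self.name, self.parents = name, list(parents)

    a, b = Node("a"), Node("b")
    c = Node("c", [a, b])       # c depends on a and b
    d = Node("d", [c])          # d depends on c

    order = toposort([d], lambda n: n.parents)
    print([n.name for n in order])   # parents first, e.g. ['a', 'b', 'c', 'd']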
def vjp_flat(f, *primals_in, has_aux=False, **static_args):
2523def vjp_flat(f, *primals_in, has_aux=False, **static_args):
2524    pvals_in = [make_known_pval(x) for x in primals_in] + [make_unknown_pval(SymbolicTensor.like(get_symval(x))) for x in primals_in]
2525    primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
2526    del primal_pvals_in
2527
2528    def f_jvp(*primals_tangents_in):
2529        jvp_ret = jvp(
2530            f,
2531            *split_half(primals_tangents_in),
2532            has_aux=has_aux,
2533            global_data="vjp",
2534            **static_args,
2535        )
2536        if has_aux:
2537            ((primals_out, tangents_out), aux) = jvp_ret
2538        else:
2539            (primals_out, tangents_out) = jvp_ret
2540        return ([*primals_out, *tangents_out], aux) if has_aux else [*primals_out, *tangents_out]
2541
2542    partial_run_flat_ret = partial_run_flat(f_jvp, pvals_in, has_aux, "vjp")
2543    if has_aux:
2544        program, pvals_out, consts, aux = partial_run_flat_ret
2545    else:
2546        program, pvals_out, consts = partial_run_flat_ret
2547
2548    primal_pvals, tangent_pvals = split_half(pvals_out)
2549    del tangent_pvals
2550    assert all(pval.is_known for pval in primal_pvals)
2551    primals_out_flat = [pval.const for pval in primal_pvals]
2552    transpose_inputs = consts + [UndefinedPrimal(t.symval) for t in tangent_pvals_in]
2553
2554    def f_vjp_flat(*cotangents):
2555        # return backward_pass(program, transpose_inputs, cotangents)
2556        undef_primals = tuple(isinstance(x, UndefinedPrimal) for x in transpose_inputs)
2557        transposed_program, new_consts = transpose_program(program, undef_primals)
2558        residuals, _ = partition_list(undef_primals, transpose_inputs)
2559        outs = run_program(transposed_program, (*new_consts, *residuals, *cotangents))
2560        return outs
2561
2562    return (primals_out_flat, f_vjp_flat, aux) if has_aux else (primals_out_flat, f_vjp_flat)
def vjp(f, *primals_in, has_aux=False, **static_args):
2565def vjp(f, *primals_in, has_aux=False, **static_args):
2566    primals_in_flat, in_tree = tree_flatten(primals_in)
2567    f, out_tree_store = flatten_fn(f, in_tree, has_aux=has_aux)
2568    vjp_ret = vjp_flat(f, *primals_in_flat, has_aux=has_aux, **static_args)
2569    if has_aux:
2570        primals_out_flat, f_vjp_flat, aux = vjp_ret
2571    else:
2572        primals_out_flat, f_vjp_flat = vjp_ret
2573    primals_out = tree_unflatten(out_tree_store(), primals_out_flat)
2574
2575    def f_vjp(*cotangents_out):
2576        cotangents_out_flat, _ = tree_flatten(cotangents_out)
2577        cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
2578
2579        return tree_unflatten(in_tree, cotangents_in_flat)
2580
2581    return (primals_out, f_vjp, aux) if has_aux else (primals_out, f_vjp)
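
A usage sketch of `vjp` (reverse mode): it returns the primal outputs plus a pullback that maps output cotangents to input cotangents with the same tree structure as the inputs. Tensor constructors are assumptions:

    def f(x):
        return x * 3.0

    x = backend.ones((4,))
    y, f_vjp = vjp(f, x)
    (x_bar,) = f_vjp(backend.ones((4,)))   # cotangent w.r.t. x; same structure as the primals tuple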
NullCotangent = None
def backward_pass(program: Program, args: List[Any], cotangents: List[Any]) -> List[Any]:
2587def backward_pass(program: Program, args: List[Any], cotangents: List[Any]) -> List[Any]:
2588    primal_env: Dict[Var, Any] = {}
2589    ct_env: Dict[Var, Any] = {}
2590
2591    def read_primal(x: Atom) -> Any:
2592        return primal_env.get(x, UndefinedPrimal(x.symval)) if type(x) is Var else x.val
2593
2594    def write_primal(v: Var, val: Any) -> None:
2595        if type(val) is not UndefinedPrimal:
2596            primal_env[v] = val
2597
2598    def read_cotangent(v: Var) -> Any:
2599        return ct_env.pop(v, backend.zeros(v.symval.shape, v.symval.dtype))
2600
2601    def write_cotangent(x: Atom, ct: Any):
2602        if type(x) is Var and ct is not NullCotangent:
2603            ct_env[x] = (ct_env[x] + ct) if x in ct_env else ct
2604
2605    list_map(write_primal, program.in_binders, args)
2606    list_map(write_cotangent, program.outs, cotangents)
2607    for instruction in program.instructions[::-1]:
2608        primals_in = list_map(read_primal, instruction.inputs)
2609        cotangents_in = list_map(read_cotangent, instruction.out_binders)
2610        inp, params = primals_in, instruction.params
2611        cotangents_out = instruction.op.T(cotangents_in, *inp, **params)
2612        list_map(write_cotangent, instruction.inputs, cotangents_out)
2613
2614    ret = [read_cotangent(v) for v, x in list_zip(program.in_binders, args) if isinstance(x, UndefinedPrimal)]
2615    return ret
def transpose_program(*args, **kwargs) -> Any:
(Source listing omitted: as with `make_program`, the renderer captured the caching decorator's `decorated_function` wrapper rather than the body of `transpose_program`.)
def grad(f, argnums=(0,), argnames='', has_aux=False, return_value=False):
2636def grad(f, argnums=(0,), argnames="", has_aux=False, return_value=False):
2637    f, rejit = (f, False) if not isinstance(f, jit) else (f.f, True)
2638    if isinstance(argnums, int):
2639        argnums = (argnums,)
2640
2641    def gfn(x, *xs, **static_args):
2642        vjp_ret = vjp(f, x, *xs, has_aux=has_aux, **static_args)
2643        if has_aux:
2644            y, f_vjp, aux = vjp_ret
2645        else:
2646            y, f_vjp = vjp_ret
2647        if np.shape(y) != ():
2648            raise TypeError("grad output must be 0-dim scalar with shape ()")
2649        gL_xs = f_vjp(backend.ones(()))
2650        gL_xs = tuple(gL_xs[i] for i in argnums) if len(argnums) > 1 else gL_xs[argnums[0]]
2651        if return_value:
2652            return ((y, aux), gL_xs) if has_aux else (y, gL_xs)
2653        else:
2654            return (gL_xs, aux) if has_aux else gL_xs
2655
2656    return jit(gfn) if rejit else gfn
def value_and_grad(f, argnums=(0,), argnames='', has_aux=False):
2659def value_and_grad(f, argnums=(0,), argnames="", has_aux=False):
2660    return grad(
2661        f,
2662        argnums=argnums,
2663        argnames=argnames,
2664        has_aux=has_aux,
2665        return_value=True,
2666    )
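
A usage sketch of `grad` and `value_and_grad` (the loss must return a 0-dim scalar, per the shape check in `grad`; `backend.ones(())` is the same scalar constructor `grad` itself uses for the cotangent seed):

    def loss(x):
        return x * x          # x is a scalar tensor, so the output has shape ()

    x = backend.ones(())
    g = grad(loss)(x)                 # d(loss)/dx at x
    y, g = value_and_grad(loss)(x)    # loss value and gradient together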
def jit_partial_run(trace, tracers, *, program):
2669def jit_partial_run(trace, tracers, *, program):
2670    in_unknowns = [not t.pval.is_known for t in tracers]
2671    program1, program2, out_unknowns, num_res = partial_run_program(program, in_unknowns)
2672    known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
2673    known_vals = [t.pval.const for t in known_tracers]
2674    outs1_res = backend.jit_op(*known_vals, program=program)
2675    outs1, res = split_list(outs1_res, len(program1.outs) - num_res)
2676    res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
2677    outs2 = [PartialRunTraceTensor(trace, make_unknown_pval(v.symval), None) for v in program2.outs]  # PartialValue has no .unknown constructor; make_unknown_pval builds the unknown PartialValue
2678    draft = InstructionDraft(
2679        backend.jit_op,
2680        res_tracers + unknown_tracers,
2681        dict(program=program2),
2682        [v.symval for v in program2.outs],
2683        map(weakref.ref, outs2),
2684    )
2685    for t in outs2:
2686        t.draft = draft
2687    return merge_lists(out_unknowns, outs1, outs2)
class jit:
2690class jit:
2691    def __init__(self, f, static_argnames=(), name=None, dynamic_axes=None):
2692        if isinstance(static_argnames, str):
2693            static_argnames = tuple(static_argnames.split(" "))
2694        assert type(static_argnames) is tuple and all(type(s) is str for s in static_argnames)
2695        self.f = f
2696        self.name = name if name is not None else self.f.__name__
2697        self.static_argnames = static_argnames
2698        self.dynamic_axes = dynamic_axes
2699
2700    @classmethod
2701    def with_options(cls, **kwargs):
2702        return partial(cls, **kwargs)
2703
2704    @classmethod
2705    def get_jit_name(cls, args, static_args, prefix="jit", short=False):
2706        name = f"{prefix}_"
2707        if short:
2708            static_args_tup = tuple(static_args.items())
2709            ids = repr(hash((prefix, args, static_args_tup)))[-4:]
2710            name = f"{prefix}_{ids}"
2711        else:
2712            for a in args:
2713                name += f"shape_{a.shape}_dtype_{a.dtype.name}_"
2714            for k, v in static_args.items():
2715                name += f"{k}_{v}_"
2716            name = name.replace("(", "L")
2717            name = name.replace(")", "R")
2718            name = name.replace(",", "C")
2719            name = name.replace(" ", "")
2720            name = name.replace(".", "D")
2721
2722        return name
2723
2724    def get_program(self, *args, **static_args):
2725        sig = inspect.signature(self.f)
2726        if all("*" not in repr(v) for v in sig.parameters.values()):
2727            args_strs = [k for k, v in sig.parameters.items() if k != "self" and k not in self.static_argnames]
2728            static_args_strs = [k for k, v in sig.parameters.items() if k != "self" and k in self.static_argnames]
2729
2730            if args:
2731                if len(args) > len(args_strs):
2732                    assert static_args_strs
2733                    args, rest = args[: len(args_strs)], args[len(args_strs) :]
2734                    new_static_args = {k: rest_arg for k, rest_arg in zip(static_args_strs, rest) if k not in static_args}
2735                    static_args = {**new_static_args, **static_args}
2736            else:
2737                args = tuple([static_args[k] if k in static_args else arg for k, arg in zip(args_strs, args)])
2738
2739        symvals_in = tree_map(lambda x: SymbolicTensor.like(get_symval(x)), args)
2740        static_args = tuple(static_args.items())
2741        if self.name is None:
2742            self.name = f"jit_{str(hash((self.f, symvals_in, static_args)))[-5:]}"
2743        program, consts, out_tree = make_program(self.f, *symvals_in, static_args=static_args, name=self.name)
2744        return program, consts, out_tree
2745
2746    def __call__(self, *args, **static_args):
2747        program, consts, out_tree = self.get_program(*args, **static_args)
2748        args, in_tree = tree_flatten(args)
2749        outs = bind(backend.jit_op, *consts, *args, program=program)
2750        return tree_unflatten(out_tree, outs)
2751
2752    def lower(self, *args, **static_args):
2753        program, consts, out_tree = self.get_program(*args, **static_args)
2754        args, in_tree = tree_flatten(args)
2755        hashed_program = Hashed(program)
2756        num_consts = program.num_consts
2757        consts, args = args[:num_consts], args[num_consts:]
2758        hashed_consts = tuple(map(Hashed, consts))
2759        jit_output = backend.jit_program(hashed_program, hashed_consts)
2760        return jit_output
2761
2762    def export(self, output_path, args, export_params=True, input_names=None, output_names=None, **kwargs):
2763        if isinstance(args, Tensor):
2764            args, static_args = (args,), dict()
2765        elif not isinstance(args[-1], dict):
2766            assert all(isinstance(a, Tensor) for a in args)
2767            static_args = dict()
2768        else:
2769            args, static_args = args[:-1], args[-1]
2770        assert isinstance(args, (tuple, list)) and isinstance(static_args, dict)
2771        jit_output = self.lower(*args, **static_args)
2772        backend.export(jit_output, output_path, export_params, input_names, output_names, **kwargs)
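
A usage sketch of `jit` used as a decorator (the program is traced per input shape/dtype and static-argument combination, then dispatched through `backend.jit_op`; the tensor constructor and operators are assumptions consistent with the rest of this module):

    @jit
    def f(x, y):
        return x * y + x

    x = backend.ones((2, 3))
    y = backend.ones((2, 3))
    out = f(x, y)

`jit.with_options(static_argnames="...")` can be used instead of plain `@jit` when some keyword arguments should be treated as static, i.e. baked into the traced program rather than traced as tensors.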