diff --git a/apprentice/agents/cre_agents/cre_agent.py b/apprentice/agents/cre_agents/cre_agent.py index afc25c3..e7e92ba 100644 --- a/apprentice/agents/cre_agents/cre_agent.py +++ b/apprentice/agents/cre_agents/cre_agent.py @@ -1409,7 +1409,7 @@ def _recover_prev_skill_app(self, action=None, skill_app = None if(skill_uid in self.skills and skill_app_uid is not None): - if(uid in getattr(self.skills[skill_uid],"skill_apps", [])): + if(self.uid in getattr(self.skills[skill_uid],"skill_apps", [])): skill_app = self.skills[skill_uid].skill_apps[skill_app_uid] return skill_app @@ -1749,7 +1749,7 @@ def get_skills(self, skill_uids=None, skill_labels=None, if(skill_uids): skills = [self.skills[uid] for uid in skill_uids if uid in self.skills] elif(skill_labels): - skills = chain([skills_by_label.get(label,[]) for label in skill_labels]) + skills = chain([self.skills_by_label.get(label,[]) for label in skill_labels]) elif(states): raise NotImplemented() else: diff --git a/apprentice/agents/cre_agents/funcs.py b/apprentice/agents/cre_agents/funcs.py index 93b5328..c56f99b 100644 --- a/apprentice/agents/cre_agents/funcs.py +++ b/apprentice/agents/cre_agents/funcs.py @@ -1,4 +1,4 @@ -from numba.types import f8, string, boolean +from numba.types import f8, string, boolean, UniTuple, unicode_type from apprentice.agents.cre_agents.extending import registries, new_register_decorator, new_register_all from apprentice.agents.cre_agents.environment import TextField from cre import CREFunc @@ -7,6 +7,82 @@ register_func = new_register_decorator("func", full_descr="CREFunc") register_all_funcs = new_register_all("func", types=[CREFunc], full_descr="CREFunc") +# --- Product rule: a^{m} * a^{n} -> a^{m + n} +@CREFunc(signature=string(string), shorthand='exp_product_rule({0})') +def ProductRule(s: str) -> str: + # make robust like PowerRule + s = s.rsplit("/", 1)[-1].strip() + + # split on either "\cdot" or "*" + if r'\cdot' in s: + left, right = [t.strip() for t in s.split(r'\cdot', 1)] + else: + left, right = [t.strip() for t in s.split('*', 1)] + + # Parse left: a^{m} + iL = left.index('^{'); jL = iL + 2; kL = left.index('}', jL) + baseL = left[:iL] + m = left[jL:kL].strip() + + # Parse right: a^{n} + iR = right.index('^{'); jR = iR + 2; kR = right.index('}', jR) + baseR = right[:iR] + n = right[jR:kR].strip() + + # optional safety + if baseL != baseR: + raise ValueError("ProductRule expects matching bases.") + + return f"{baseL}^{{{m} + {n}}}" + + +@CREFunc(signature=string(string), shorthand='exp_product_simplify({0})') +def SimplifyProduct(s: str) -> str: + s = s.rsplit("/", 1)[-1].strip() + i = s.index('^{'); j = i + 2; k = s.index('}', j) + base = s[:i] + m_str, n_str = [t.strip() for t in s[j:k].split('+', 1)] + val = int(m_str) + int(n_str) + return f"{base}^{{{val}}}" + + + +# --- Quotient rule: a^{m} / a^{n} -> a^{m - n} +@CREFunc(signature=string(string), shorthand='exp_quotient_rule({0})') +def QuotientRule(s: str) -> str: + # print("INITIAL PROBLEM", s) + + inner = s[len(r"\frac{"):-1] + left, right = inner.split("}{") + + baseL, m = left.split("^{") + m = m[:-1] + baseR, n = right.split("^{") + n = n[:-1] + + # print("FOUND M", baseL, m) + # print("FOUND N", baseR, n) + + if baseL != baseR: + raise ValueError("QuotientRule expects matching bases.") + + return f"{baseL}^{{-{n} + {m}}}" + + +@CREFunc(signature=string(string), shorthand='exp_quotient_simplify({0})') +def SimplifyQuotient(s: str) -> str: + # s = s.rsplit("/", 1)[-1].strip() + i = s.index('^{'); j = i + 2; k = s.index('}', j) + print("I", i) + base = s[:i] + print("BASE", base) + n_str, m_str = [t.strip() for t in s[j:k].split('+', 1)] + print("M and N", m_str, n_str) + val = int(m_str) + int(n_str) + print("VAL", val) + return f"{base}^{{{val}}}" + + @CREFunc(signature=boolean(string,string), shorthand = '{0} == {1}', commutes=True) @@ -158,6 +234,67 @@ def AcrossMultiply(a, b): return (float(a.value) * float(b.value)) +# === Prior Knowledge & Cross-Domain Functions === + +@CREFunc(signature=f8(TextField), shorthand='Num({0})') +def Num(tf): + """Convert TextField value to a float.""" + return float(str(tf.value).replace(',', '').strip()) + +@CREFunc(signature=string(TextField), shorthand='Str({0})') +def StrTF(tf): + """Extract string from a TextField.""" + return str(tf.value) + +@CREFunc(signature=boolean(TextField), shorthand='IsDen({0})') +def IsDen(tf): + """Return True if this TextField looks like a denominator.""" + return 'den' in tf.id + +@CREFunc(signature=boolean(TextField), shorthand='IsNum({0})') +def IsNum(tf): + """Return True if this TextField looks like a numerator.""" + return 'num' in tf.id + +@CREFunc(signature=boolean(string, string), shorthand='lower({0}) == lower({1})', commutes=True) +def EqualsIgnoreCase(a, b): + """Case-insensitive equality for strings.""" + return a.lower() == b.lower() + +@CREFunc(signature=f8(string), shorthand='parse({0})') +def ParseFloat(s): + """Parse numeric string into float.""" + return float(s.replace(',', '').strip()) + +@CREFunc(signature=string(string, string), shorthand='join({0},{1})', commutes=False) +def JoinWithSpace(a, b): + """Join two strings with a space.""" + return f"{a} {b}" + +@CREFunc(signature=f8(f8), shorthand='sq({0})') +def Sq(a): + """Square a number.""" + return a * a + +@CREFunc(signature=f8(f8, f8), shorthand='sq({0}) + {1}') +def SqPlus(a, b): + """Square a number and add another.""" + return (a * a) + b + +@CREFunc(signature=f8(f8, f8), shorthand='Reuse({0},{1})', commutes=True) +def ReusePrior(a, b): + """Prefer previously known value a; otherwise use b.""" + return a if a != 0 else b + +@CREFunc(signature=f8(f8), shorthand='Recall({0})') +def RecallValue(a): + """Surface a previously computed value.""" + return a + +@CREFunc(signature=f8(f8, f8), shorthand='AvgPrior({0},{1})', commutes=True) +def AveragePrior(a, b): + """Combine earlier results (simple average).""" + return (a + b) / 2 ##### Define all CREFuncs above this line #####