
Commit 3d9145a

refactor code
1 parent f4f601f commit 3d9145a

File tree

19 files changed: +60 -123 lines

minitorch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # MODULE 0
 # MODULE 1
-import minitorch.scalar_functions as scalar_functions # noqa: F401,F403
+import minitorch.scalar.functions as scalar_functions # noqa: F401,F403
 
 from .autodiff import * # noqa: F401,F403
 from .cuda_ops import * # noqa: F401,F403
@@ -15,7 +15,7 @@
 from .nn import * # noqa: F401,F403
 from .optim import * # noqa: F401,F403
 from .scalar import Scalar, ScalarHistory, derivative_check # noqa: F401,F403
-from .scalar_functions import ScalarFunction # noqa: F401,F403
+from minitorch.scalar.functions import ScalarFunction # noqa: F401,F403
 
 # MODULE 2
 from .tensor import * # noqa: F401,F403
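
The public surface is unchanged by this refactor: the flat module minitorch.scalar_functions now lives at minitorch.scalar.functions, and the re-exports above keep existing call sites working. A minimal sketch of the resulting import paths (assuming the package is importable):

    # New location of the scalar functions module after this commit:
    import minitorch.scalar.functions as scalar_functions

    # The re-exports in minitorch/__init__.py keep downstream code as-is:
    from minitorch import Scalar, ScalarFunction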

minitorch/autodiff.py

Lines changed: 0 additions & 3 deletions
@@ -3,9 +3,6 @@
 
 from typing_extensions import Protocol
 
-# ## Task 1.1
-# Central Difference calculation
-
 
 def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6) -> Any:
     r"""
File renamed without changes.
File renamed without changes.

minitorch/modules.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

minitorch/scalar/__init__.py

Whitespace-only changes.

minitorch/scalar_functions.py renamed to minitorch/scalar/functions.py

Lines changed: 14 additions & 14 deletions
@@ -4,8 +4,8 @@
 
 import minitorch
 
-from . import operators
-from .autodiff import Context
+from .. import common_operators
+from ..autodiff import Context
 
 if TYPE_CHECKING:
     from typing import Tuple
@@ -86,12 +86,12 @@ class Log(ScalarFunction):
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
         ctx.save_for_backward(a)
-        return operators.log(a)
+        return common_operators.log(a)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> float:
         (a,) = ctx.saved_values
-        return operators.log_back(a, d_output)
+        return common_operators.log_back(a, d_output)
 
 
 class Mul(ScalarFunction):
@@ -100,7 +100,7 @@ class Mul(ScalarFunction):
     @staticmethod
     def forward(ctx: Context, a: float, b: float) -> float:
         ctx.save_for_backward(a, b)
-        return operators.mul(a, b)
+        return common_operators.mul(a, b)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
@@ -114,20 +114,20 @@ class Inv(ScalarFunction):
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
         ctx.save_for_backward(a)
-        return operators.inv(a)
+        return common_operators.inv(a)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> float:
         (a,) = ctx.saved_values
-        return operators.inv_back(a, d_output)
+        return common_operators.inv_back(a, d_output)
 
 
 class Neg(ScalarFunction):
     "Negation function"
 
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
-        return operators.neg(a)
+        return common_operators.neg(a)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> float:
@@ -139,7 +139,7 @@ class Sigmoid(ScalarFunction):
 
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
-        s = operators.sigmoid(a)
+        s = common_operators.sigmoid(a)
         ctx.save_for_backward(s)
         return s
 
@@ -155,20 +155,20 @@ class ReLU(ScalarFunction):
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
         ctx.save_for_backward(a)
-        return operators.relu(a)
+        return common_operators.relu(a)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> float:
         (a,) = ctx.saved_values
-        return operators.relu_back(a, d_output)
+        return common_operators.relu_back(a, d_output)
 
 
 class Exp(ScalarFunction):
     "Exp function"
 
     @staticmethod
     def forward(ctx: Context, a: float) -> float:
-        e = operators.exp(a)
+        e = common_operators.exp(a)
         ctx.save_for_backward(e)
         return e
 
@@ -183,7 +183,7 @@ class LT(ScalarFunction):
 
     @staticmethod
     def forward(ctx: Context, a: float, b: float) -> float:
-        return operators.lt(a, b)
+        return common_operators.lt(a, b)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
@@ -195,7 +195,7 @@ class EQ(ScalarFunction):
 
     @staticmethod
     def forward(ctx: Context, a: float, b: float) -> float:
-        return operators.eq(a, b)
+        return common_operators.eq(a, b)
 
     @staticmethod
     def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
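
Every hunk in this file is the same mechanical substitution: the old sibling module operators is now the package-level common_operators, reached with .. from inside minitorch/scalar/. The forward/backward contract itself is untouched. As a sketch of that contract under the new layout, here is a hypothetical Square function (not part of this commit) written in the same style:

    from minitorch.autodiff import Context
    from minitorch.scalar.functions import ScalarFunction

    class Square(ScalarFunction):
        "Square function $f(a) = a^2$"

        @staticmethod
        def forward(ctx: Context, a: float) -> float:
            # Stash the input; backward needs it for the gradient.
            ctx.save_for_backward(a)
            return a * a

        @staticmethod
        def backward(ctx: Context, d_output: float) -> float:
            # d/da (a^2) = 2a, scaled by the upstream gradient.
            (a,) = ctx.saved_values
            return 2.0 * a * d_output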

minitorch/scalar.py renamed to minitorch/scalar/scalar.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 
 import numpy as np
 
-from .autodiff import Context, Variable, backpropagate, central_difference
-from .scalar_functions import (
+from ..autodiff import Context, Variable, backpropagate, central_difference
+from .functions import (
     EQ,
     LT,
     Add,
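
The edits here are pure bookkeeping for the extra directory level: from inside minitorch/scalar/scalar.py, .. climbs to the package root and . stays within the scalar subpackage.

    # Layout after this commit, as reflected in the file tree above:
    #
    #   minitorch/
    #       autodiff.py
    #       scalar/
    #           __init__.py
    #           scalar.py      (was minitorch/scalar.py)
    #           functions.py   (was minitorch/scalar_functions.py)
    #
    # Hence, from inside minitorch/scalar/scalar.py:
    from ..autodiff import Context, Variable, backpropagate, central_difference
    from .functions import EQ, LT, Add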

minitorch/tensor/__init__.py

Whitespace-only changes.

minitorch/tensor_data.py renamed to minitorch/tensor/data.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from numpy import array, float64
 from typing_extensions import TypeAlias
 
-from .operators import prod
+from ..common_operators import prod
 
 MAX_DIMS = 32
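
prod simply moves along with the renamed operators module. In this file it is the usual shape-to-size helper; a sketch of that use, assuming prod multiplies an iterable of numbers (its exact signature is not shown in this diff):

    from minitorch.common_operators import prod

    shape = (2, 3, 4)
    size = int(prod(list(shape)))  # 2 * 3 * 4 = 24 elements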
