Skip to content

Commit fe94910

Browse files
authored
feat: use EUC for AI IF, CLASSIFY, and SCORE when connection is not provided (#2507)
Fixes b/489038951 🦕
1 parent a5ddcea commit fe94910

File tree

11 files changed

+33
-21
lines changed
  • bigframes
  • tests/unit/core/compile/sqlglot/expressions
  • third_party/bigframes_vendored/ibis/expr/operations

11 files changed

+33
-21
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def if_(
745745
or pandas Series.
746746
connection_id (str, optional):
747747
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
748-
If not provided, the connection from the current session will be used.
748+
If not provided, the query uses your end-user credential.
749749
750750
Returns:
751751
bigframes.series.Series: A new series of bools.
@@ -756,7 +756,7 @@ def if_(
756756

757757
operator = ai_ops.AIIf(
758758
prompt_context=tuple(prompt_context),
759-
connection_id=_resolve_connection_id(series_list[0], connection_id),
759+
connection_id=connection_id,
760760
)
761761

762762
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -800,7 +800,7 @@ def classify(
800800
Categories to classify the input into.
801801
connection_id (str, optional):
802802
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
803-
If not provided, the connection from the current session will be used.
803+
If not provided, the query uses your end-user credential.
804804
805805
Returns:
806806
bigframes.series.Series: A new series of strings.
@@ -812,7 +812,7 @@ def classify(
812812
operator = ai_ops.AIClassify(
813813
prompt_context=tuple(prompt_context),
814814
categories=tuple(categories),
815-
connection_id=_resolve_connection_id(series_list[0], connection_id),
815+
connection_id=connection_id,
816816
)
817817

818818
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -853,7 +853,7 @@ def score(
853853
or pandas Series.
854854
connection_id (str, optional):
855855
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
856-
If not provided, the connection from the current session will be used.
856+
If not provided, the query uses your end-user credential.
857857
858858
Returns:
859859
bigframes.series.Series: A new series of double (float) values.
@@ -864,7 +864,7 @@ def score(
864864

865865
operator = ai_ops.AIScore(
866866
prompt_context=tuple(prompt_context),
867-
connection_id=_resolve_connection_id(series_list[0], connection_id),
867+
connection_id=connection_id,
868868
)
869869

870870
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
113113
)
114114
)
115115

116-
endpoit = op_args.get("endpoint", None)
117-
if endpoit is not None:
118-
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116+
endpoint = op_args.get("endpoint", None)
117+
if endpoint is not None:
118+
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
119119

120120
request_type = op_args.get("request_type", None)
121121
if request_type is not None:

bigframes/operations/ai_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class AIIf(base_ops.NaryOp):
123123
name: ClassVar[str] = "ai_if"
124124

125125
prompt_context: Tuple[str | None, ...]
126-
connection_id: str
126+
connection_id: str | None
127127

128128
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
129129
return dtypes.BOOL_DTYPE
@@ -135,7 +135,7 @@ class AIClassify(base_ops.NaryOp):
135135

136136
prompt_context: Tuple[str | None, ...]
137137
categories: tuple[str, ...]
138-
connection_id: str
138+
connection_id: str | None
139139

140140
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
141141
return dtypes.STRING_DTYPE
@@ -146,7 +146,7 @@ class AIScore(base_ops.NaryOp):
146146
name: ClassVar[str] = "ai_score"
147147

148148
prompt_context: Tuple[str | None, ...]
149-
connection_id: str
149+
connection_id: str | None
150150

151151
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
152152
return dtypes.FLOAT_DTYPE
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.CLASSIFY(input => (`string_col`), categories => ['greeting', 'rejection']) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SCORE(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def test_ai_generate_double_with_model_param(
281281
snapshot.assert_match(sql, "out.sql")
282282

283283

284-
def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
284+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
285+
def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
285286
col_name = "string_col"
286287

287288
op = ops.AIIf(
288289
prompt_context=(None, " is the same as ", None),
289-
connection_id=CONNECTION_ID,
290+
connection_id=connection_id,
290291
)
291292

292293
sql = utils._apply_ops_to_sql(
@@ -296,26 +297,28 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
296297
snapshot.assert_match(sql, "out.sql")
297298

298299

299-
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
300+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
301+
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
300302
col_name = "string_col"
301303

302304
op = ops.AIClassify(
303305
prompt_context=(None,),
304306
categories=("greeting", "rejection"),
305-
connection_id=CONNECTION_ID,
307+
connection_id=connection_id,
306308
)
307309

308310
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
309311

310312
snapshot.assert_match(sql, "out.sql")
311313

312314

313-
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
315+
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
316+
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
314317
col_name = "string_col"
315318

316319
op = ops.AIScore(
317320
prompt_context=(None, " is the same as ", None),
318-
connection_id=CONNECTION_ID,
321+
connection_id=connection_id,
319322
)
320323

321324
sql = utils._apply_ops_to_sql(

0 commit comments

Comments
 (0)