
Comment by gcr 8 days ago


Thanks for such a cool project! It's immediately apparent how to use it and I appreciate the brief examples.

Quick question: in the breast cancer example from the README, a simple linear support vector machine from sklearn (incidentally, the first baseline I tried) seems to outperform TabPFN. Is this expected? I know the example is there to demonstrate ease of use rather than state-of-the-art performance, but I'm curious.

    # TabPFN (prediction_probabilities as computed in the README example)
    print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))
    # ROC AUC: 0.996299494264216

    # LinearSVC on the same X_train / X_test split
    from sklearn.svm import LinearSVC

    clf = LinearSVC(C=0.01).fit(X_train, y_train)
    print("ROC AUC:", roc_auc_score(y_test, clf.decision_function(X_test)))
    # ROC AUC: 0.997532996176144

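
One way to check whether a gap this small means anything is to repeat the single split many times and see how far the score moves. A minimal sketch, assuming scikit-learn is installed; it reuses the LinearSVC(C=0.01) baseline above, while the 50/50 stratified split, the 20 seeds, and the helper names single_split_auc and cv_mean_auc are illustrative choices, not the README's setup. It prints the spread of single-split ROC AUC next to the spread of 5-fold cross-validated means over reshuffled folds.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.svm import LinearSVC

X, y = load_breast_cancer(return_X_y=True)
SEEDS = range(20)  # illustrative number of repetitions


def single_split_auc(seed, test_size=0.5):
    """Hypothetical helper: ROC AUC of LinearSVC(C=0.01) on one random stratified split."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=seed
    )
    clf = LinearSVC(C=0.01).fit(X_train, y_train)
    return roc_auc_score(y_test, clf.decision_function(X_test))


def cv_mean_auc(seed, n_splits=5):
    """Hypothetical helper: mean ROC AUC over 5 stratified folds, reshuffled with `seed`."""
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return cross_val_score(LinearSVC(C=0.01), X, y, cv=cv, scoring="roc_auc").mean()


single = np.array([single_split_auc(s) for s in SEEDS])
cv_means = np.array([cv_mean_auc(s) for s in SEEDS])

print(f"single split:   min={single.min():.4f} max={single.max():.4f} std={single.std():.4f}")
print(f"5-fold CV mean: min={cv_means.min():.4f} max={cv_means.max():.4f} std={cv_means.std():.4f}")
```

If the single-split range is wider than the gap between two models, one split on its own cannot say which model is better.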
noahho 8 days ago

Author here! The breast cancer dataset is simple and heavily saturated (most methods score close to the maximum), so small differences between methods are expected. As you say, results from a single train/test split can be noisy, because they depend on how the data happens to be divided, especially for a saturated dataset like this one. Cross-validation reduces this variance by averaging over multiple splits. Here's a quick 5-fold cross-validation run:

  TabPFN mean ROC AUC: 0.9973
  SVM mean ROC AUC: 0.9903
  TabPFN per split: [0.99737963 0.99639699 0.99966931 0.99338624 0.99966465]
  SVM per split: [0.99312152 0.98788077 0.99603175 0.98313492 0.99128102]

  from sklearn.model_selection import cross_val_score
  from tabpfn import TabPFNClassifier
  from sklearn.datasets import load_breast_cancer
  from sklearn.svm import LinearSVC
  import numpy as np

  data = load_breast_cancer()
  X, y = data.data, data.target

  # TabPFN: 5-fold cross-validated ROC AUC
  tabpfn_clf = TabPFNClassifier()
  tabpfn_scores = cross_val_score(tabpfn_clf, X, y, cv=5, scoring='roc_auc')
  print("TabPFN per split:", tabpfn_scores)
  print("TabPFN mean ROC AUC:", np.mean(tabpfn_scores))

  # Linear SVM with the same C as above (ROC AUC scored via decision_function)
  svm_clf = LinearSVC(C=0.01)
  svm_scores = cross_val_score(svm_clf, X, y, cv=5, scoring='roc_auc')
  print("SVM per split:", svm_scores)
  print("SVM mean ROC AUC:", np.mean(svm_scores))

It's hard to communicate this properly. We should probably have a more favourable example ready, but we just included the simplest one!
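
Because both cross_val_score calls above use cv=5 on the same X and y, they should evaluate on the same unshuffled stratified folds (assuming sklearn recognizes TabPFNClassifier as a classifier, which is what makes cv=5 resolve to StratifiedKFold), so split i of one model can be paired with split i of the other. A small stdlib-only sketch using the per-split numbers printed above:

```python
from statistics import mean

# Per-split ROC AUC values copied from the 5-fold cross-validation output above.
tabpfn = [0.99737963, 0.99639699, 0.99966931, 0.99338624, 0.99966465]
svm = [0.99312152, 0.98788077, 0.99603175, 0.98313492, 0.99128102]

# Assumes both runs used identical folds, so each index is a paired comparison.
diffs = [a - b for a, b in zip(tabpfn, svm)]

print("per-split difference (TabPFN - SVM):", [round(d, 4) for d in diffs])
print("mean difference:", round(mean(diffs), 4))
print("splits where TabPFN is higher:", sum(d > 0 for d in diffs), "of", len(diffs))
```

With these numbers TabPFN is higher on all five splits, and the mean paired difference (about 0.007) matches the gap between the two printed means, 0.9973 vs 0.9903.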
  • gcr 8 days ago

    Thanks, this is helpful!

    I certainly appreciate how the example in the README makes it instantly apparent how to use the code.